base.py 24.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from ..wrapped_decorator import signature_safe_contextmanager, wrap_decorator
S
songyouwei 已提交
15
import decorator
16
import contextlib
17 18
import functools
import inspect
19
import sys
20 21 22
import numpy as np
from paddle.fluid import core
from paddle.fluid import framework
H
hong 已提交
23
from paddle.fluid.multiprocess_utils import CleanupFuncRegistrar
M
minqiyang 已提交
24
from .tracer import Tracer
Z
Zeng Jinle 已提交
25
import logging
J
Jiabin Yang 已提交
26
import objgraph
27
from ..data_feeder import convert_dtype
28

29
__all__ = [
    'no_grad',
    'no_grad_',
    'grad',
    'guard',
    'enable_dygraph',
    'disable_dygraph',
    'enabled',
    'to_variable',
]
33 34


35 36 37 38 39 40 41 42 43 44 45
def _switch_to_static_graph_(func):
    """Wrap ``func`` so that every call runs under ``_dygraph_guard(None)``.

    The returned callable forwards all positional and keyword arguments to
    ``func`` and returns its result. During the call the dygraph tracer is
    set to ``None`` via ``framework._dygraph_guard``.
    """

    def _run_without_dygraph(*args, **kwargs):
        with framework._dygraph_guard(None):
            return func(*args, **kwargs)

    return _run_without_dygraph


# Decorator form of ``_switch_to_static_graph_``, built with ``wrap_decorator``.
switch_to_static_graph = wrap_decorator(_switch_to_static_graph_)


46 47 48 49 50 51
@signature_safe_contextmanager
def program_desc_tracing_guard(enable):
    """Temporarily set ``_enable_program_desc_tracing`` on the dygraph tracer.

    Args:
        enable (bool): value to assign to the current tracer's
            ``_enable_program_desc_tracing`` flag inside the ``with`` body.

    If there is no current tracer, the context simply yields and changes
    nothing. Otherwise the previous flag value is restored on exit, including
    when the body raises.
    """
    active_tracer = framework._dygraph_tracer()
    if not active_tracer:
        yield
        return

    saved_flag = active_tracer._enable_program_desc_tracing
    active_tracer._enable_program_desc_tracing = enable
    try:
        yield
    finally:
        active_tracer._enable_program_desc_tracing = saved_flag
57 58


59 60 61
# The ``guard()`` context manager that ``enable_dygraph`` creates and enters,
# and that ``disable_dygraph`` exits and resets. It is ``None`` while dygraph
# mode has not been enabled through ``enable_dygraph``.
_functional_dygraph_context_manager = None


62 63
@signature_safe_contextmanager
def param_guard(parameters):
    """Temporarily swap dygraph tensors in ``parameters`` for static variables.

    When not in dygraph mode and ``parameters`` is non-empty, every
    ``core.VarBase`` value in ``parameters`` is replaced by a static-graph
    variable for the duration of the ``with`` body. The original entries are
    then restored, even if the body (or the conversion itself) raises.
    Otherwise the context simply yields.

    Args:
        parameters (dict): name -> value mapping. This is a reference to a
            layer's ``_parameters`` or ``_buffers`` dict and is modified in
            place.
    """
    # Note: parameters is a reference of self._parameters or self._buffers
    if framework.in_dygraph_mode() or not parameters:
        yield
        return

    origin_parameters = parameters.copy()
    try:
        # Iterate the snapshot so the dict being rewritten is not the one
        # being iterated.
        for name, var_base in origin_parameters.items():
            if not isinstance(var_base, core.VarBase):
                continue
            # Convert ParamBase into Parameter with same attributes in dy2stat.
            if isinstance(var_base, framework.ParamBase):
                new_var = var_base._to_static_var(to_parameter=True)
            elif var_base.name in var_base.block.vars:
                # Reuse the variable if it has been created before.
                new_var = var_base.block.vars[var_base.name]
            else:
                # Note(Aurelius84): Convert VarBase in self._buffers into
                # Variable with same attributes and set persistable=True to
                # allow saving this var. Users can create a VarBase in
                # `__init__` like a `mask` Tensor or `hidden_0` in RNN layers,
                # which is equivalent to a Parameter and necessary for
                # inferring. It will be pruned if it's not necessary for
                # inferring.
                new_var = var_base._to_static_var(
                    to_parameter=False, persistable=True)
            parameters[name] = new_var
        yield
    finally:
        # Restore the original entries even on error, so the layer is not
        # left holding static-graph variables after a failed conversion.
        parameters.update(origin_parameters)


91
def enabled():
    """
    This function checks whether the program runs in dynamic graph mode or not.
    You can enter dynamic graph mode with :ref:`api_fluid_dygraph_guard` api,
    or enable and disable dynamic graph mode with :ref:`api_fluid_dygraph_enable_dygraph`
    and :ref:`api_fluid_dygraph_disable_dygraph` api .

    **Note**:
        ``fluid.dygraph.enabled`` is the alias of ``fluid.in_dygraph_mode``, and
        ``fluid.in_dygraph_mode`` is recommended to use.

    Returns:
        bool: Whether the program is running in dynamic graph mode.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            fluid.enable_dygraph()  # Now we are in dygraph mode
            print(fluid.dygraph.enabled())  # True
            fluid.disable_dygraph()
            print(fluid.dygraph.enabled())  # False
    """
    dygraph_on = framework.in_dygraph_mode()
    return dygraph_on
116 117


118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
def enable_dygraph(place=None):
    """
    This function enables dynamic graph mode.

    It creates a ``guard(place=place)`` context, enters it, and keeps it in a
    module-level variable until ``disable_dygraph`` is called. If dygraph mode
    was already enabled through this function, the call does nothing.

    Parameters:
        place(fluid.CPUPlace or fluid.CUDAPlace, optional): Place to execute dygraph.
            If None, the running place will be determined according to the way of paddle compilation. Default: None

    return:
        None

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            fluid.enable_dygraph()  # Now we are in dygraph mode
            print(fluid.in_dygraph_mode())  # True
            fluid.disable_dygraph()
            print(fluid.in_dygraph_mode())  # False
    """
    global _functional_dygraph_context_manager
    if _functional_dygraph_context_manager is not None:
        # Already enabled through enable_dygraph(); nothing to do.
        return

    _functional_dygraph_context_manager = guard(place=place)
    _functional_dygraph_context_manager.__enter__()

    # call disable_dygraph when Python exit
    CleanupFuncRegistrar.register(disable_dygraph)

147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170

def disable_dygraph():
    """
    This function disables dynamic graph mode.

    It exits the context entered by ``enable_dygraph`` (passing the current
    ``sys.exc_info()`` to ``__exit__``) and forgets it. Calling it when dygraph
    mode was not enabled through ``enable_dygraph`` does nothing.

    return:
        None

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            fluid.enable_dygraph()  # Now we are in dygraph mode
            print(fluid.in_dygraph_mode())  # True
            fluid.disable_dygraph()
            print(fluid.in_dygraph_mode())  # False
    """
    global _functional_dygraph_context_manager
    active_manager = _functional_dygraph_context_manager
    if active_manager is None:
        return
    active_manager.__exit__(*sys.exc_info())
    _functional_dygraph_context_manager = None


171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244
@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
    """Temporarily set ``_train_mode`` on the current dygraph tracer.

    Args:
        is_train (bool): value assigned to the tracer's ``_train_mode`` inside
            the ``with`` body.

    If there is no current tracer, the context only yields. Otherwise the
    previous mode is restored on exit, including when the body raises.
    """
    current_tracer = framework._dygraph_tracer()
    if not current_tracer:
        yield
        return

    previous_mode = current_tracer._train_mode
    current_tracer._train_mode = is_train
    try:
        yield
    finally:
        current_tracer._train_mode = previous_mode


def no_grad(func=None):
    """
    :api_attr: imperative

    Create a context which disables dygraph gradient calculation.
    In this mode, the result of every computation will have `stop_gradient=True`.

    Also functions as a decorator. (Make sure to instantiate without parenthesis.)

    Called without ``func`` it returns a context manager that sets the
    tracer's ``_train_mode`` to False. Called with ``func`` it returns ``func``
    wrapped so that every call runs inside that context.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        # use as context manager

        data = np.array([[2, 3], [4, 5]]).astype('float32')
        with fluid.dygraph.guard():
            l0 = fluid.Linear(2, 2)  # l0.weight.gradient() is None
            l1 = fluid.Linear(2, 2)
            with fluid.dygraph.no_grad():
                # l1.weight.stop_gradient is False
                tmp = l1.weight * 2  # tmp.stop_gradient is True
            x = fluid.dygraph.to_variable(data)
            y = l0(x) + tmp
            o = l1(y)
            o.backward()
            print(tmp.gradient() is None)  # True
            print(l0.weight.gradient() is None)  # False

        # use as decorator

        @fluid.dygraph.no_grad
        def test_layer():
            with fluid.dygraph.guard():
                inp = np.ones([3, 1024], dtype='float32')
                t = fluid.dygraph.base.to_variable(inp)
                linear1 = fluid.Linear(1024, 4, bias_attr=False)
                linear2 = fluid.Linear(4, 4)
                ret = linear1(t)
                dy_ret = linear2(ret)

        test_layer()

    """
    if func is not None:

        @decorator.decorator
        def _call_without_grad(func, *args, **kwargs):
            with _switch_tracer_mode_guard_(is_train=False):
                return func(*args, **kwargs)

        return _call_without_grad(func)

    return _switch_tracer_mode_guard_(is_train=False)


class no_grad_:
    """
    :api_attr: imperative

    Create a context which disables dygraph gradient calculation.
    In this mode, the result of every computation will have `stop_gradient` set
    to `True`.

    Also functions as a decorator. (Make sure to use an instance.)

    One instance may be entered several times at once, for example by a
    recursive function decorated with a single ``@paddle.no_grad()``
    instance. Each exit restores the train mode saved by its matching enter.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle

        paddle.disable_static()

        # use as context manager

        data = np.array([[2, 3], [4, 5]]).astype('float32')
        l0 = paddle.nn.Linear(2, 2)  # l0.weight.gradient() is None
        l1 = paddle.nn.Linear(2, 2)
        with paddle.no_grad():
            # l1.weight.stop_gradient is False
            tmp = l1.weight * 2  # tmp.stop_gradient is True
        x = paddle.to_tensor(data)
        y = l0(x) + tmp
        o = l1(y)
        o.backward()
        print(tmp.gradient() is None)  # True
        print(l0.weight.gradient() is None)  # False

        # use as decorator

        @paddle.no_grad()
        def test_layer():
            inp = np.ones([3, 1024], dtype='float32')
            t = paddle.to_tensor(inp)
            linear1 = paddle.nn.Linear(1024, 4, bias_attr=False)
            linear2 = paddle.nn.Linear(4, 4)
            ret = linear1(t)
            dy_ret = linear2(ret)

        test_layer()
    """

    def __init__(self):
        # One entry per active __enter__: either (tracer, saved_train_mode)
        # or None when no tracer was active. A single attribute would be
        # overwritten by nested/recursive use of the same instance, leaving
        # train mode disabled after the outermost exit.
        self._saved_modes = []

    def __call__(self, func):
        @decorator.decorator
        def _decorate_function(func, *args, **kwargs):
            with self:
                return func(*args, **kwargs)

        @decorator.decorator
        def _decorate_generator(func, *args, **kwargs):
            gen = func(*args, **kwargs)
            with self:
                for x in gen:
                    yield x

        if inspect.isgeneratorfunction(func):
            return _decorate_generator(func)
        else:
            return _decorate_function(func)

    def __enter__(self):
        tracer = framework._dygraph_tracer()
        if tracer:
            self._saved_modes.append((tracer, tracer._train_mode))
            tracer._train_mode = False
        else:
            # Keep enter/exit calls paired even when no tracer is active.
            self._saved_modes.append(None)

    def __exit__(self, *args):
        saved = self._saved_modes.pop()
        if saved is not None:
            # Restore on the tracer that was modified at enter time.
            tracer, mode = saved
            tracer._train_mode = mode
320 321


S
rename  
sneaxiy 已提交
322
@signature_safe_contextmanager
def guard(place=None):
    """
    :api_attr: imperative

    This context will create a dygraph context for dygraph to run, using python ``with`` statement.

    Each entry creates a fresh main/startup ``Program`` pair, a fresh
    ``Tracer`` and a fresh unique-name scope, and activates the tracer for
    the duration of the ``with`` body.

    Parameters:
        place(fluid.CPUPlace or fluid.CUDAPlace, optional): Place to execute dygraph.
            If None, the current expected place (``framework._current_expected_place()``)
            is used. Default: None

    return:
        None

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        with fluid.dygraph.guard():
            inp = np.ones([3, 1024], dtype='float32')
            t = fluid.dygraph.base.to_variable(inp)
            linear1 = fluid.Linear(1024, 4, bias_attr=False)
            linear2 = fluid.Linear(4, 4)
            ret = linear1(t)
            dy_ret = linear2(ret)

    """
    train = framework.Program()
    startup = framework.Program()
    tracer = Tracer()

    if place is not None:
        expected_place = place
    else:
        expected_place = framework._current_expected_place()
    tracer._expected_place = expected_place

    with framework.program_guard(train, startup):
        with framework.unique_name.guard():
            with framework._dygraph_guard(tracer):
                # Use the resolved place rather than the raw (possibly None)
                # argument, so the expected place seen inside the guard
                # matches the place the tracer runs on.
                with framework._dygraph_place_guard(expected_place):
                    yield
368 369


370
def _print_debug_msg(parameter_list, limit=5, is_test=False):
    """Report variable-count statistics for dygraph debugging.

    Args:
        parameter_list (list): tracer variables; only its length is used.
        limit (int): ``limit`` forwarded to ``objgraph.show_growth``.
        is_test (bool): when True, return the counts instead of logging them.

    Returns:
        tuple|None: ``(unique_name_size, tracer_var_size, alive_cpp_var_size)``
        when ``is_test`` is True. Otherwise None, which is also returned when
        dygraph debug mode (``FLAGS_dygraph_debug``) is not enabled.
    """
    if not core._is_dygraph_debug_enabled():
        # logging.warn is a deprecated alias; logging.warning is the API.
        logging.warning(
            'Debug mode is not enabled. Please set FLAGS_dygraph_debug=1 to enable debug'
        )
        return
    unique_name_size = len(framework.unique_name.generator.ids)
    tracer_var_size = len(parameter_list)
    alive_cpp_var_size = len(core.VarBase._alive_vars())
    if not is_test:
        logging.warning(
            'unique_name num: {}, tracer vars num: {}, alive cpp vars num: {}'
            .format(unique_name_size, tracer_var_size, alive_cpp_var_size))
        objgraph.show_growth(limit=limit)
    else:
        return unique_name_size, tracer_var_size, alive_cpp_var_size
Z
Zeng Jinle 已提交
386 387


388 389 390 391
@framework.dygraph_only
def grad(outputs,
         inputs,
         grad_outputs=None,
         retain_graph=None,
         create_graph=False,
         only_inputs=True,
         allow_unused=False,
         no_grad_vars=None):
    """
    .. note::
        **This API is ONLY available in Dygraph mode.**

    This API computes the sum of gradients of `outputs` with respect to each `inputs` .

    Parameters:
        outputs (Tensor|list(Tensor)|tuple(Tensor)): the output Tensor or
            Tensor list/tuple of the graph to compute gradients.
        inputs (Tensor|list(Tensor)|tuple(Tensor)): the input Tensor or
            Tensor list/tuple of the graph to compute gradients. The returned
            values of this API are the gradients of `inputs` .
        grad_outputs (Tensor|list(Tensor|None)|tuple(Tensor|None), optional):
            initial gradient values of `outputs` . If `grad_outputs` is None,
            the initial gradient values of `outputs` would be Tensors filled with 1;
            if `grad_outputs` is not None, it must have the same length as `outputs` ,
            and in this case, the initial gradient value of the i-th `outputs` would
            be: (1) a Tensor filled with 1 when the i-th element of `grad_outputs`
            is None; (2) the i-th element of `grad_outputs` when the i-th element of
            `grad_outputs` is a Tensor. Default None.
        retain_graph (bool, optional): whether to retain the forward graph which
            is used to calculate the gradient. When it is True, the graph would
            be retained, in which way users can calculate backward twice for the
            same graph. When it is False, the graph would be freed. Default None,
            which means it is equal to `create_graph` .
        create_graph (bool, optional): whether to create the gradient graphs of
            the computing process. When it is True, higher order derivatives are
            supported to compute; when it is False, the gradient graphs of the
            computing process would be discarded. Default False.
        only_inputs (bool, optional): whether to only compute the gradients of
            `inputs` . If it is False, the gradients of all remaining leaf
            Tensors in the graph would be also computed and accumulated.
            If it is True, only the gradients of `inputs` would be computed.
            Default True. only_inputs=False is under development, and it is
            not supported yet.
        allow_unused (bool, optional): whether to raise error or return None if some
            Tensors of `inputs` are unreachable in the graph. If some Tensors of
            `inputs` are unreachable in the graph (i.e., their gradients are None),
            error would be raised if allow_unused=False, or None would be returned as
            their gradients if allow_unused=True. Default False.
        no_grad_vars (Tensor|list(Tensor)|tuple(Tensor)|set(Tensor), optional):
            the Tensors whose gradients are not needed to compute. Default None.

    Returns:
        tuple: a tuple of Tensors, whose length is the same as the Tensor number
        inside `inputs`, and the i-th returned Tensor is the sum of gradients of
        `outputs` with respect to the i-th `inputs`.

    Raises:
        AssertionError: if any argument fails validation. Validation is done
        with explicit checks, so it also runs under ``python -O``.

    Examples 1:
        .. code-block:: python

            import paddle
            paddle.disable_static()

            def test_dygraph_grad(create_graph):
                x = paddle.ones(shape=[1], dtype='float32')
                x.stop_gradient = False
                y = x * x

                # Since y = x * x, dx = 2 * x
                dx = paddle.grad(
                        outputs=[y],
                        inputs=[x],
                        create_graph=create_graph,
                        retain_graph=True)[0]

                z = y + dx

                # If create_graph = False, the gradient of dx
                # would not be backpropagated. Therefore,
                # z = x * x + dx, and x.gradient() = 2 * x = 2.0

                # If create_graph = True, the gradient of dx
                # would be backpropagated. Therefore,
                # z = x * x + dx = x * x + 2 * x, and
                # x.gradient() = 2 * x + 2 = 4.0

                z.backward()
                return x.gradient()

            print(test_dygraph_grad(create_graph=False)) # [2.]
            print(test_dygraph_grad(create_graph=True)) # [4.]

    Examples 2:
        .. code-block:: python

            import paddle
            paddle.disable_static()

            def test_dygraph_grad(grad_outputs=None):
                x = paddle.fill_constant(shape=[1], value=2.0, dtype='float32')
                x.stop_gradient = False

                y1 = x * x
                y2 = x * 3

                # If grad_outputs=None, dy1 = [1], dy2 = [1].
                # If grad_outputs=[g1, g2], then:
                #    - dy1 = [1] if g1 is None else g1
                #    - dy2 = [1] if g2 is None else g2

                # Since y1 = x * x, dx = 2 * x * dy1.
                # Since y2 = x * 3, dx = 3 * dy2.
                # Therefore, the final result would be:
                # dx = 2 * x * dy1 + 3 * dy2 = 4 * dy1 + 3 * dy2.

                dx = paddle.grad(
                    outputs=[y1, y2],
                    inputs=[x],
                    grad_outputs=grad_outputs)[0]

                return dx.numpy()

            grad_value = paddle.fill_constant(shape=[1], value=4.0, dtype='float32')

            # dy1 = [1], dy2 = [1]
            print(test_dygraph_grad(None)) # [7.]

            # dy1 = [1], dy2 = [4]
            print(test_dygraph_grad([None, grad_value])) # [16.]

            # dy1 = [4], dy2 = [1]
            print(test_dygraph_grad([grad_value, None])) # [19.]

            # dy1 = [3], dy2 = [4]
            grad_y1 = paddle.fill_constant(shape=[1], value=3.0, dtype='float32')
            print(test_dygraph_grad([grad_y1, grad_value])) # [24.]
    """

    def _require(condition, message):
        # Explicit raise instead of `assert`: assertions are stripped under
        # `python -O`. The exception type stays AssertionError for callers
        # that already catch it.
        if not condition:
            raise AssertionError(message)

    def check_in_out(in_out_list, name):
        _require(in_out_list is not None, "{} should not be None".format(name))

        if isinstance(in_out_list, (list, tuple)):
            _require(len(in_out_list) > 0, "{} cannot be empty".format(name))
            for each_var in in_out_list:
                _require(
                    isinstance(each_var, core.VarBase),
                    "Elements of {} must be Variable".format(name))
            return in_out_list
        else:
            _require(
                isinstance(in_out_list, core.VarBase),
                "{} must be Variable or list of Variable".format(name))
            return [in_out_list]

    outputs = check_in_out(outputs, 'outputs')
    inputs = check_in_out(inputs, 'inputs')

    if grad_outputs is not None:
        if not isinstance(grad_outputs, (list, tuple)):
            grad_outputs = [grad_outputs]

        for each_var in grad_outputs:
            if each_var is not None:
                _require(
                    isinstance(each_var, core.VarBase),
                    "grad_outputs must be None, a Variable or a list containing None or Variables"
                )
    else:
        grad_outputs = []

    if len(grad_outputs) > 0:
        _require(
            len(grad_outputs) == len(outputs),
            "The length of grad_outputs must be equal to outputs")

    if no_grad_vars is None:
        no_grad_vars = []
    elif isinstance(no_grad_vars, core.VarBase):
        no_grad_vars = [no_grad_vars]
    elif isinstance(no_grad_vars, (list, tuple, set)):
        no_grad_vars = list(no_grad_vars)
        for var in no_grad_vars:
            _require(
                isinstance(var, core.VarBase),
                "no_grad_vars can only contains Variable")
    else:
        raise AssertionError(
            "no_grad_vars must be None, Variable or list/tuple/set of Variables")

    _require(
        isinstance(create_graph, bool), "create_graph must be True or False")

    if retain_graph is None:
        retain_graph = create_graph

    _require(
        isinstance(retain_graph, bool),
        "retain_graph must be None, True or False")

    _require(
        isinstance(allow_unused, bool), "allow_unused must be True or False")

    _require(isinstance(only_inputs, bool), "only_inputs must be True or False")
    _require(only_inputs, "only_inputs=False is not supported yet")

    place = core.Place()
    place.set_place(framework._current_expected_place())
    return core.dygraph_partial_grad(inputs, outputs, grad_outputs,
                                     no_grad_vars, place, create_graph,
                                     retain_graph, allow_unused, only_inputs)
594 595


596
@framework.dygraph_only
def to_variable(value, name=None, zero_copy=None, dtype=None):
    r"""
    :api_attr: imperative

    The API will create a ``Variable`` or ``ComplexVariable`` object from
    tuple, list, numpy\.ndarray, Variable or ComplexVariable object.

    Parameters:
        value(tuple|list|ndarray|Variable|Tensor|ComplexVariable): Initial data.
            Can be a list, tuple, NumPy ndarray, Variable, Tensor, ComplexVariable.
            The shape can be multi-dimensional. The data type is one of
            numpy\.{float16, float32, float64, int16, int32, int64,
            uint8, uint16, complex64, complex128}.
        name(str, optional): The default value is None. Normally there is no
            need for user to set this property. For more information, please
            refer to :ref:`api_guide_Name` .
        zero_copy(bool, optional): Whether to share memory with the input numpy
            array. This parameter only works with CPUPlace and will be set to
            True when it is None. Default: None.
        dtype(str, optional): The desired data type of returned ``Variable`` .
            Can be 'bool' , 'float16' , 'float32' , 'float64' , 'int8' , 'int16' ,
            'int32' , 'int64' , 'uint8' . Default: None.

    Returns:
        Variable or ComplexVariable: If ``value`` is a tuple/list/numpy\.ndarray object,
            return ``Tensor`` created from the corresponding numpy\.ndarray object, which has
            same data type and shape with ``value``. If ``value`` is a Variable or ComplexVariable
            object, just return ``value``.

    Raises:
        TypeError: if ``value`` is not one of the supported types.

    Examples:

     .. code-block:: python

        import numpy as np
        import paddle.fluid as fluid

        with fluid.dygraph.guard(fluid.CPUPlace()):
            x = np.ones([2, 2], np.float32)
            y = fluid.dygraph.to_variable(x, zero_copy=False)
            x[0][0] = -1
            y[0][0].numpy()  # array([1.], dtype=float32)
            y = fluid.dygraph.to_variable(x)
            x[0][0] = 0
            y[0][0].numpy()  # array([0.], dtype=float32)
            c = np.array([2+1j, 2])
            z = fluid.dygraph.to_variable(c)
            z.numpy() # array([2.+1.j, 2.+0.j])
            z.dtype # 'complex128'

            y = fluid.dygraph.to_variable([[0.1, 1.2], [2.2, 3.1], [4.9, 5.2]])
            y.shape     # [3L, 2L]

            y = fluid.dygraph.to_variable(((0.1, 1.2), (2.2, 3.1), (4.9, 5.2)), dtype='int32')
            y.shape     # [3L, 2L]

    """
    support_type = (list, tuple, np.ndarray, core.VarBase, framework.Variable,
                    framework.ComplexVariable, core.Tensor, core.LoDTensor)
    if not isinstance(value, support_type):
        raise TypeError(
            "The type of 'value' in fluid.dygraph.to_variable must be %s, but received %s."
            % (support_type, type(value)))

    # Already a variable: hand it back unchanged.
    if isinstance(value, (core.VarBase, framework.Variable,
                          framework.ComplexVariable)):
        return value

    # A raw C++ tensor is wrapped directly.
    if isinstance(value, (core.Tensor, core.LoDTensor)):
        return core.VarBase(value)

    # Remaining inputs: list, tuple or numpy.ndarray.
    target_place = framework._current_expected_place()
    if isinstance(target_place, framework.core.CPUPlace):
        if zero_copy is None:
            zero_copy = True
    else:
        assert not zero_copy, "zero_copy mode can only be used with CPUPlace"

    array = value if isinstance(value, np.ndarray) else np.array(value)

    if dtype is not None:
        dtype = convert_dtype(dtype)
        if array.dtype != dtype:
            array = array.astype(dtype)

    if not np.iscomplexobj(array):
        return core.VarBase(
            value=array,
            place=target_place,
            persistable=False,
            zero_copy=zero_copy,
            name=name if name else '')

    # Complex data is split into separate real and imaginary VarBases that
    # share a base name.
    base_name = name if name else framework.unique_name.generate(
        '_generated_var')
    real_var = core.VarBase(
        value=array.real,
        place=target_place,
        persistable=False,
        zero_copy=zero_copy,
        name=base_name + ".real")
    imag_var = core.VarBase(
        value=array.imag,
        place=target_place,
        persistable=False,
        zero_copy=zero_copy,
        name=base_name + ".imag")
    return framework.ComplexVariable(real_var, imag_var)