# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import gast
import inspect
import warnings
import textwrap
import threading
import collections
import numpy as np
from paddle.fluid import core, scope_guard
from paddle.fluid import framework
from paddle.fluid import executor
from paddle.fluid import unique_name
from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
from paddle.fluid.dygraph.dygraph_to_static.origin_info import attach_origin_info, create_and_update_origin_info_map
from paddle.fluid.dygraph.dygraph_to_static.origin_info import update_op_callstack_with_origin_info
from paddle.fluid.dygraph.dygraph_to_static.error import attach_error_data, ERROR_DATA

__all__ = ['ProgramTranslator', 'convert_to_static']


class FunctionCache(object):
    """
    Caches the transformed functions to avoid redundant conversions of the same function.
    """

    def __init__(self):
        # Caches the converted static functions. {dygraph_func: static_func}
        self._converted_static_func_caches = dict()
        # Caches the converted ast node for same source code. {source_code: ast_root}
        self._code_to_ast_caches = dict()
        self._dygraph_to_static = DygraphToStaticAst()

    def convert_with_cache(self, func):
        """
        Returns the cached static function, or converts and caches it the first
        time the function is encountered.
        """
        # If hit cache, return it directly.
        static_func = self._converted_static_func_caches.get(func, None)

        if static_func is None:
            static_func = self._convert(func)
            self._converted_static_func_caches[func] = static_func

        return static_func

    def _convert(self, func):
        """
        Converts a dygraph function into a static function. For two functions with
        the same dedented source code, the second one reuses the transformed ast
        node of the first.

        For example:
            # A.py
            def foo(x, y):
                z = x + y
                return z

            # B.py
            def foo(x, y):
                z = x + y
                return z

        If the conversion of A.foo happens after B.foo, it will reuse the transformed
        ast node of B.foo to speed up the conversion.
        """
        # Note: In Python 2, inspecting a decorated function directly raises OSError,
        # and function.__wrapped__ holds the actual function.
        func = getattr(func, '__wrapped__', func)
        source_code = func_to_source_code(func)

        # TODO(liym27):
        #  Consider this case: source_code in self._code_to_ast_caches,
        #  but actually they are methods in different classes.
        #  Maybe use (__class__, source_code) as key
        if source_code in self._code_to_ast_caches:
            root_wrapper = self._code_to_ast_caches[source_code]
        else:
            root = gast.parse(source_code)
            root = attach_origin_info(root, func)
            root_wrapper = self._dygraph_to_static.get_static_ast(root)
            self._code_to_ast_caches[source_code] = root_wrapper

        # Get static function from AST
        static_func, file_name = ast_to_func(root_wrapper.node, func)

        create_and_update_origin_info_map(root_wrapper.node, static_func)
        return static_func

    def exist(self, func):
        return func in self._converted_static_func_caches


_CACHE_LOCK = threading.Lock()
_FUNCTION_CACHE = FunctionCache()


def convert_to_static(function):
    """
    Transforms a dygraph function into a static function using the cache mechanism.

    Args:
        function(callable): The function with dygraph layers that will be converted into static layers.
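
    Examples:
        .. code-block:: python

            # A minimal usage sketch. `foo` is an illustrative user-defined
            # dygraph function, not part of this module.
            from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static

            def foo(x, y):
                return x + y

            static_foo = convert_to_static(foo)        # converted and cached
            static_foo_again = convert_to_static(foo)  # served from the cache
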
    """
    with _CACHE_LOCK:
        static_func = _FUNCTION_CACHE.convert_with_cache(function)
        return static_func


class FunctionSpec(object):
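    """
    Wraps a dygraph function together with the arguments it was called with.
    FunctionSpec objects are used as cache keys in ProgramCache: two calls map to
    the same cached program when they refer to the same function (for a Layer
    method, the same source code and layer instance).
    """
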
    def __init__(self, func, args, kwargs):
        self._dyfunc = func
        self._args = args
        self._kwargs = kwargs

        # TODO(liym27): func has multi layer decorator
        dyfunc = getattr(func, '__wrapped__', func)
        self._dyfunc_code = inspect.getsource(dyfunc)

    def is_method(self):
        return self._args and isinstance(self._args[0], layers.Layer)

    def parameters(self, include_sublayer=True):
        """
        Returns the parameters of the decorated layer. If `include_sublayer` is
        True, the parameters created in sublayers are included as well.
        """
        params = collections.OrderedDict()
        if self.is_method():
            layer_instance = self._args[0]
            if include_sublayer:
                params = layer_instance.parameters()
                names = [p.name for p in params]
                params = collections.OrderedDict(zip(names, params))
            else:
                params = layer_instance._parameters
        return params

    def buffers(self, include_sublayer=True):
        """
        Returns the Variable buffers of the decorated layer. If `include_sublayer`
        is True, the Variable buffers created in sublayers are included as well.
        """
        buffers = collections.OrderedDict()
        if self.is_method():
            layer_instance = self._args[0]
            if include_sublayer:
                buffers = layer_instance.buffers()
                names = [buffer.name for buffer in buffers]
                buffers = collections.OrderedDict(zip(names, buffers))
            else:
                buffers = layer_instance._buffers
        return buffers

    @switch_to_static_graph
    def to_static_inputs(self, main_program):
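        """
        Creates the static-graph input Variables for self.args in `main_program`:
        numpy.ndarray inputs become newly created feed Variables, VarBase inputs
        are mirrored as Variables with the same name/shape/dtype, and any other
        object is passed through unchanged. The nested structure of self.args is
        preserved.
        """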
        inputs = []
        block = main_program.global_block()
        for input_var in flatten(self.args):
            if isinstance(input_var, np.ndarray):
                feed_layer = block.create_var(
                    name=unique_name.generate('feed'),
                    shape=list(input_var.shape),
                    dtype=input_var.dtype,
                    is_data=True,
                    need_check_feed=False)
            elif isinstance(input_var, core.VarBase):
                feed_layer = block.create_var(
                    name=input_var.name,
                    shape=list(input_var.shape),
                    dtype=input_var.dtype,
                    stop_gradient=input_var.stop_gradient,
                    need_check_feed=False)
            else:
                feed_layer = input_var

            inputs.append(feed_layer)
        # Restores the nested structure as self.args
        return pack_sequence_as(self.args, inputs)

    @property
    def dyfunc(self):
        return self._dyfunc

    @property
    def args(self):
        return self._args

    def __key(self):
        # Note: if dygraph function is a method of class,
        # consider instance info as hash key.
        if self.is_method():
            # NOTE: we can use Layer's (instance + function code) as hash key.
            # An instance will not hold two identical methods.
            return self._dyfunc_code, self._args[0]
        else:
            return self._dyfunc

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        return self.__key() == other.__key()


# Flag that indicates whether running code under `@declarative`
_in_declarative_mode_ = False


def in_declarative_mode():
    """
    Return a bool value that indicates whether running code under `@declarative`

    """
    return _in_declarative_mode_


@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
    global _in_declarative_mode_
    original_val = _in_declarative_mode_
    _in_declarative_mode_ = is_declarative
    try:
        yield
    finally:
        # Restore the flag even if the wrapped block raises.
        _in_declarative_mode_ = original_val


class ConcreteProgram(object):
    def __init__(self,
                 inputs,
                 outputs,
                 parameters,
                 func,
                 main_program,
                 startup_program=None):
        self.inputs = inputs
        self.outputs = outputs
        self.main_program = main_program
        self.startup_program = startup_program
        self.parameters = parameters
        self.func_spec = func

    @staticmethod
    @switch_to_static_graph
    def from_func_spec(func_spec):
        """
        Builds the main_program with specialized inputs and returns the outputs
        of the program as the fetch_list.
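
        A rough sketch of how this is driven by ProgramCache (`net` and `x` are
        placeholders for a user-defined Layer instance and its input):

        .. code-block:: python

            # func_spec = FunctionSpec(net.forward, (net, x), {})
            # concrete_program = ConcreteProgram.from_func_spec(func_spec)
            # concrete_program.main_program    # the Program built from the function
            # concrete_program.inputs          # feed Variables created for the inputs
            # concrete_program.outputs         # output Variables of the function
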
        """
        # Transforms dygraph function into static function and caches it.
        dygraph_function = func_spec.dyfunc
        static_func = convert_to_static(dygraph_function)

        main_program, startup_program = framework.Program(), framework.Program()
        # Note: The random seed should be synchronized into the cached program
        # if it is set in `fluid.dygraph.guard` because some ops rely on it,
        # such as `fluid.layers.dropout`.
        main_program.random_seed = framework.default_main_program().random_seed
        startup_program.random_seed = framework.default_startup_program(
        ).random_seed

        with framework.program_guard(main_program, startup_program):
            with _switch_declarative_mode_guard_(is_declarative=True):
                # 1. Adds `fluid.data` layers for input if needed
                inputs = func_spec.to_static_inputs(main_program)

                # 2. Gets all ParamBases and buffered VarBases in the function
                all_parameters_and_buffers = list(func_spec.parameters().values(
                )) + list(func_spec.buffers().values())

                # 3. Builds program only once and returns the output Variables.
                with param_guard(func_spec.parameters(False)), param_guard(
                        func_spec.buffers(False)):
                    try:
                        outputs = static_func(*inputs)
                    except BaseException as e:
                        # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
                        attach_error_data(e)
                        raise

                if not isinstance(outputs,
                                  (tuple, list)) and outputs is not None:
                    outputs = [outputs]

        main_program = update_op_callstack_with_origin_info(main_program)

        return ConcreteProgram(
            inputs=inputs,
            outputs=outputs,
            parameters=all_parameters_and_buffers,
            func=dygraph_function,
            main_program=main_program,
            startup_program=startup_program)


class ProgramCache(object):
    """
    Wrapper class that caches the programs built from dygraph functions, keyed by FunctionSpec.
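
    A rough usage sketch (`net` and `x` stand for a user-defined Layer instance
    and its input; the cache is normally driven by ProgramTranslator):

    .. code-block:: python

        # cache = ProgramCache()
        # spec = FunctionSpec(net.forward, (net, x), {})
        # concrete_program, partial_layer = cache[spec]   # built once, then cached
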
    """

    def __init__(self):
        self._caches = collections.OrderedDict()

    def _build_once(self, func_spec):
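        # Builds the static-graph ConcreteProgram for `func_spec` and wraps it in a
        # callable partial program layer that runs it with dygraph inputs.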
        concrete_program = ConcreteProgram.from_func_spec(func_spec)
        return concrete_program, partial_program_from(concrete_program)

    def __getitem__(self, item):
        if not isinstance(item, FunctionSpec):
            raise ValueError(
                'type(item) should be FunctionSpec, but received %s' %
                type(item))
        if item not in self._caches:
            self._caches[item] = self._build_once(item)
        return self._caches[item]

    def get_program(self, item):
        if not isinstance(item, FunctionSpec):
            raise ValueError(
                "Input item's type should be FunctionSpec, but received %s" %
                type(item))
        if item not in self._caches:
            raise RuntimeError(
                "Failed to find program for input item, please decorate input function by `@declarative`."
            )
        return self._caches[item]

    def last(self):
        assert len(
            self._caches) >= 1, "No valid cached program in ProgramCache."
        key = next(reversed(self._caches.keys()))
        return key, self._caches[key]


def synchronized(func):
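    """Decorator that serializes calls to the wrapped function with a per-function lock."""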
    func.__lock__ = threading.Lock()

    def lock_func(*args, **kwargs):
        with func.__lock__:
            return func(*args, **kwargs)

    return lock_func


class ProgramTranslator(object):
    """
    Class to translate a dygraph function into a static graph function. The
    object of this class is a singleton.

    Args:
        None.

    Returns:
        ProgramTranslator: the singleton object.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            # Both calls return the same object because ProgramTranslator is a singleton
            fluid.dygraph.ProgramTranslator()
            fluid.dygraph.ProgramTranslator.get_instance()

    """

    _singleton_lock = threading.Lock()
    _instance = None

    @synchronized
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # object.__new__ takes no extra arguments in Python 3.
            cls._instance = object.__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            with cls._singleton_lock:
                cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        if cls._instance is not None:
            cls._instance._initialized = False
            cls._instance.__init__()

    def __init__(self):
        # Make sure that __init__ is called only once.
        if self._initialized:
            return
        self._initialized = True
        self._program_cache = ProgramCache()
        self.enable_declarative = True

    def enable(self, enable_declarative):
        """
        Enable or disable the conversion from imperative to declarative mode by
        ProgramTranslator globally.

        Args:
            enable_declarative (bool): True or False to enable or disable the declarative mode.

        Returns:
            None.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                @fluid.dygraph.jit.declarative
                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()
                prog_trans.enable(False)

                x = np.ones([1, 2])
                # The declarative mode is disabled so the func runs in dygraph mode
                with fluid.dygraph.guard():
                    print(func(x).numpy()) # [[2. 2.]]

        """
        check_type(enable_declarative, "enable_declarative", bool,
                   "ProgramTranslator.enable")
        self.enable_declarative = enable_declarative

    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Returns the output dygraph VarBase for the dygraph function. The dygraph
        function will be translated into a static graph function so the underlying
        numerical result will be calculated in declarative mode.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            VarBase or tuple of VarBase: the dygraph VarBase containing the
                numerical result.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                with fluid.dygraph.guard():
                    x = np.ones([1, 2])
                    x_v = prog_trans.get_output(func, x)
                    print(x_v.numpy()) # [[0. 0.]]

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
        if not self.enable_declarative:
            warnings.warn(
                "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        function_spec = FunctionSpec(dygraph_func, args, kwargs)
        concrete_program, partial_program_layer = self._program_cache[
            function_spec]

        if args and isinstance(args[0], layers.Layer):
            # Synchronize self.training attribute.
            partial_program_layer.training = args[0].training
            args = args[1:]
        try:
            return partial_program_layer(args)

        except BaseException as e:
            # NOTE:
            # 1. If e is raised in compile time, e should have been attached to ERROR_DATA before;
            # 2. If e is raised in runtime, e should be attached to ERROR_DATA here.
            if not hasattr(e, ERROR_DATA):
                # runtime error
                attach_error_data(e, in_runtime=True)
            raise

    def get_func(self, dygraph_func):
        """
        Returns a callable function which converts imperative dygraph APIs of
        the input dygraph_func into declarative net-building APIs, which means
        it doesn't return an immediate numerical result as get_output does.
        Users should handle Program and Executor by themselves.

        Args:
            dygraph_func (callable): the dygraph function.

        Returns:
            callable: a function that converts imperative dygraph APIs into
            declarative net-building APIs.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                static_func = prog_trans.get_func(func)
                print(callable(static_func)) # True

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
        if not self.enable_declarative:
            warnings.warn(
                "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will "
                "just return dygraph output.")
            return dygraph_func

        static_func = convert_to_static(dygraph_func)
        return static_func

    def get_program(self, dygraph_func, *args, **kwargs):
        """
        Returns the translated static program and input/output variables from
        the dygraph function. The users can use the program to run by executor.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            tuple of (main_program, startup_program, inputs, outputs) whose
            types are (Program, Program, list of Variable, list of Variable).
            main_program: the converted main program.
            startup_program: the converted startup program.
            inputs: list of input Variables which need to be fed.
            outputs: list of output Variables which users can fetch.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                x = np.ones([1, 2])
                main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x)
                print([i.name for i in inputs])
                # ['feed_0'] the feed input variable name representing x
                print([o.name for o in outputs])
                # ['_generated_var_4'] the fetch output variable name representing x_v

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
        if not self.enable_declarative:
            warnings.warn(
                "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        func_spec = FunctionSpec(dygraph_func, args, kwargs)
        concrete_program, _ = self._program_cache[func_spec]
        # Note: concrete_program holds all input/output info, including non-Variable
        input_vars = [
            var for var in concrete_program.inputs
            if isinstance(var, framework.Variable)
        ]
        output_vars = [
            var for var in concrete_program.outputs
            if isinstance(var, framework.Variable)
        ]

        return concrete_program.main_program, \
               concrete_program.startup_program, \
               input_vars, \
               output_vars

    def get_code(self, dygraph_func):
        """
        Returns the translated static function string code from the dygraph function.

        Args:
            dygraph_func (callable): the dygraph function.

        Returns:
            str: the string code of the translated static function.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                code = prog_trans.get_code(func)
                print(type(code)) # <class 'str'>

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_code"
        # Gets AST from dygraph function
        raw_code = inspect.getsource(dygraph_func)
        code = textwrap.dedent(raw_code)
        root = gast.parse(code)

        # Transform AST
        dygraph_to_static = DygraphToStaticAst()
        root_wrapper = dygraph_to_static.get_static_ast(root)

        # Get source_code
        source_code = ast_to_source_code(root_wrapper.node)
        return source_code

    def get_program_cache(self):
        """
        Returns the ProgramCache instance. This method is used by PaddlePaddle
        developers to manage program cache in ProgramTranslator. Normal users
        don't have to call this method.

        Returns:
            ProgramCache: ProgramCache instance of ProgramTranslator.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                prog_trans = fluid.dygraph.ProgramTranslator()
                prog_cache = prog_trans.get_program_cache()

        """
        return self._program_cache