# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import collections
import functools
import inspect
import textwrap
import threading
import warnings

import gast
import numpy as np

from paddle.fluid import core, scope_guard
from paddle.fluid import framework
from paddle.fluid import executor
from paddle.fluid import unique_name
from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import func_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_func
from paddle.fluid.wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
39

40
# Names exported by `from <this module> import *`.
__all__ = ['ProgramTranslator', 'convert_to_static']
41 42 43 44 45 46 47 48


class FunctionCache(object):
    """
    Memoizes dygraph-to-static conversions so that the same function is not
    transformed more than once.

    Two levels of caching are kept:
        * function object -> converted static function
        * dedented source code -> transformed AST wrapper
    """

    def __init__(self):
        # {dygraph_func: static_func}
        self._converted_static_func_caches = dict()
        # {source_code: ast_root}; shared by functions with identical source.
        self._code_to_ast_caches = dict()
        self._dygraph_to_static = DygraphToStaticAst()

    def convert_with_cache(self, func):
        """
        Returns the static counterpart of `func`, converting it on the first
        request and serving the cached result afterwards.

        Args:
            func(callable): The dygraph function to convert.

        Returns:
            The converted static function.
        """
        cached = self._converted_static_func_caches.get(func, None)
        if cached is not None:
            # Cache hit: nothing to convert.
            return cached

        converted = self._convert(func)
        self._converted_static_func_caches[func] = converted
        return converted

    def _convert(self, func):
        """
        Converts a dygraph function into a static function. When two functions
        share the same dedented source code, the second one reuses the
        transformed AST of the first.

        For example:
            # A.py
            def foo(x, y):
                z = x + y
                return z

            # B.py
            def foo(x, y):
                z = x + y
                return z

        If A.foo is converted after B.foo, the AST already transformed for
        B.foo is reused, which speeds up the conversion.
        """
        # Note: In Python2, inspecting a decorated function directly raises
        # OSError; function.__wrapped__ holds the actual function.
        func = getattr(func, '__wrapped__', func)
        source_code = func_to_source_code(func)

        if source_code not in self._code_to_ast_caches:
            parsed_root = gast.parse(source_code)
            self._code_to_ast_caches[
                source_code] = self._dygraph_to_static.get_static_ast(
                    parsed_root)
        root_wrapper = self._code_to_ast_caches[source_code]

        # Build the static function from the transformed AST; the second
        # element of the returned pair is not needed here.
        static_func, _ = ast_to_func(root_wrapper.node, func)
        return static_func

    def exist(self, func):
        """
        Tells whether `func` already has a cached static function.
        """
        return func in self._converted_static_func_caches
104 105


106 107 108 109
# Lock held by `convert_to_static` while it accesses `_FUNCTION_CACHE`.
_CACHE_LOCK = threading.Lock()
# Module-wide cache of converted functions used by `convert_to_static`.
_FUNCTION_CACHE = FunctionCache()


110
def convert_to_static(function):
    """
    Converts a dygraph function into a static function, going through the
    module-level function cache.

    Args:
        function(callable): The function with dygraph layers that will be converted into static layers.

    Returns:
        The static function produced (or previously cached) for `function`.
    """
    # The shared cache is only touched while holding the module lock.
    with _CACHE_LOCK:
        return _FUNCTION_CACHE.convert_with_cache(function)


122 123 124 125 126
class FunctionSpec(object):
    """
    Bundles a dygraph function with the positional and keyword arguments of
    one call. Instances are hashable and comparable so that they can be used
    as keys of `ProgramCache`.

    Args:
        func(callable): The dygraph function (or method) being described.
        args(tuple): Positional arguments of the call. When `func` is a method
            of a `Layer`, `args[0]` is the layer instance.
        kwargs(dict): Keyword arguments of the call.
    """

    def __init__(self, func, args, kwargs):
        self._dyfunc = func
        self._args = args
        self._kwargs = kwargs

        # Decorated functions keep the real function in `__wrapped__`; the
        # source code of that function is part of the hash key for methods.
        dyfunc = getattr(func, '__wrapped__', func)
        self._dyfunc_code = inspect.getsource(dyfunc)

    def is_method(self):
        """
        Returns True when the first positional argument is a `Layer`
        instance, i.e. the function is called as a method of a layer.
        """
        # bool() keeps the result a real boolean even when `args` is empty.
        return bool(self._args) and isinstance(self._args[0], layers.Layer)

    def parameters(self, include_sublayer=True):
        """
        Returns parameters of decorated layers. If set `include_sublayer` True,
        the parameters created in sub layers will be added.

        Returns:
            An ordered mapping of parameters. It is empty when the function
            is not a method of a layer.
        """
        params = collections.OrderedDict()
        if self.is_method():
            layer_instance = self._args[0]
            if include_sublayer:
                params = layer_instance.parameters()
                names = [p.name for p in params]
                params = collections.OrderedDict(zip(names, params))
            else:
                # NOTE: this is the layer's own `_parameters` container,
                # not a copy.
                params = layer_instance._parameters
        return params

    def buffers(self, include_sublayer=True):
        """
        Returns Variable buffers of decorated layers. If set `include_sublayer` True,
        the Variable buffers created in sub layers will be added.

        Returns:
            An ordered mapping of buffers. It is empty when the function is
            not a method of a layer.
        """
        buffers = collections.OrderedDict()
        if self.is_method():
            layer_instance = self._args[0]
            if include_sublayer:
                buffers = layer_instance.buffers()
                names = [buf.name for buf in buffers]
                buffers = collections.OrderedDict(zip(names, buffers))
            else:
                # NOTE: this is the layer's own `_buffers` container,
                # not a copy.
                buffers = layer_instance._buffers
        return buffers

    @switch_to_static_graph
    def to_static_inputs(self, main_program):
        """
        Creates the static-graph counterparts of the positional arguments in
        the global block of `main_program`.

        `numpy.ndarray` arguments become feed variables with a generated name,
        `core.VarBase` arguments become variables with the same name, and any
        other argument is passed through unchanged. Only positional arguments
        are processed here.

        Returns:
            The inputs packed in the same nested structure as `self.args`.
        """
        inputs = []
        block = main_program.global_block()
        for input_var in flatten(self.args):
            if isinstance(input_var, np.ndarray):
                feed_layer = block.create_var(
                    name=unique_name.generate('feed'),
                    shape=list(input_var.shape),
                    dtype=input_var.dtype,
                    is_data=True,
                    need_check_feed=False)
            elif isinstance(input_var, core.VarBase):
                feed_layer = block.create_var(
                    name=input_var.name,
                    shape=list(input_var.shape),
                    dtype=input_var.dtype,
                    stop_gradient=input_var.stop_gradient,
                    need_check_feed=False)
            else:
                feed_layer = input_var

            inputs.append(feed_layer)
        # Restores the nested structure as self.args
        return pack_sequence_as(self.args, inputs)

    @property
    def dyfunc(self):
        """The dygraph function passed to the constructor."""
        return self._dyfunc

    @property
    def args(self):
        """The positional arguments passed to the constructor."""
        return self._args

    def __key(self):
        # Note: if dygraph function is a method of class,
        # consider instance info as hash key.
        if self.is_method():
            # NOTE: we can use Layer's (instance + function code) as hash key.
            # An instance will not hold two identical methods
            return self._dyfunc_code, self._args[0]
        else:
            return self._dyfunc

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        # Compare against `other`; comparing the key with itself would make
        # every pair of specs equal and break cache lookups on hash collision.
        if not isinstance(other, FunctionSpec):
            return NotImplemented
        return self.__key() == other.__key()

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`, so define it explicitly.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238
# Flag that indicates whether running code under `@declarative`.
# It is switched by `_switch_declarative_mode_guard_` and read through
# `in_declarative_mode()`.
_in_declarative_mode_ = False


def in_declarative_mode():
    """
    Return a bool value that indicates whether running code under `@declarative`.

    Returns:
        The current value of the module-level `_in_declarative_mode_` flag.
    """
    return _in_declarative_mode_


@signature_safe_contextmanager
def _switch_declarative_mode_guard_(is_declarative=True):
    """
    Context manager that sets the module-level `_in_declarative_mode_` flag
    to `is_declarative` for the duration of the `with` body and restores the
    previous value afterwards.

    Args:
        is_declarative(bool): Value of the flag inside the `with` body.
    """
    global _in_declarative_mode_
    original_val = _in_declarative_mode_
    _in_declarative_mode_ = is_declarative
    try:
        yield
    finally:
        # Restore the flag even when the guarded body raises; otherwise a
        # failed conversion would leave the process in declarative mode.
        _in_declarative_mode_ = original_val


239 240 241 242 243 244 245
class ConcreteProgram(object):
    """
    Record of one static program built from a dygraph function: its input
    and output variables, its parameters and buffers, and the main/startup
    programs that were populated while building it.
    """

    def __init__(self,
                 inputs,
                 outputs,
                 parameters,
                 func,
                 main_program,
                 startup_program=None):
        self.inputs = inputs
        self.outputs = outputs
        self.parameters = parameters
        # NOTE: stored under the name `func_spec`; `from_func_spec` passes
        # the dygraph function here.
        self.func_spec = func
        self.main_program = main_program
        self.startup_program = startup_program

    @staticmethod
    @switch_to_static_graph
    def from_func_spec(func_spec):
        """
        Builds the main_program with specialized inputs and returns outputs
        of program as fetch_list.

        Args:
            func_spec(FunctionSpec): Describes the dygraph function and the
                arguments it is called with.

        Returns:
            ConcreteProgram: the program built for `func_spec`.
        """
        # Transforms dygraph function into static function and caches it.
        dyfunc = func_spec.dyfunc
        static_fn = convert_to_static(dyfunc)

        main_prog = framework.Program()
        startup_prog = framework.Program()
        # Note: The random seed should be synchronized into cached program
        # if set in `fluid.dygraph_guard` because some ops rely on it, such as
        # `fluid.layers.dropout`.
        main_prog.random_seed = framework.default_main_program().random_seed
        startup_prog.random_seed = framework.default_startup_program(
        ).random_seed

        with framework.program_guard(main_prog, startup_prog), \
                _switch_declarative_mode_guard_(is_declarative=True):
            # 1. Adds `fluid.data` layers for input if needed
            static_inputs = func_spec.to_static_inputs(main_prog)

            # 2. Gets all ParamBases and buffered VarBases in the function
            params_and_buffers = list(func_spec.parameters().values())
            params_and_buffers.extend(func_spec.buffers().values())

            # 3. Builds program only once and returns the output Variables.
            with param_guard(func_spec.parameters(False)), param_guard(
                    func_spec.buffers(False)):
                outputs = static_fn(*static_inputs)

            # A single (non-None) output is wrapped into a list.
            if outputs is not None and not isinstance(outputs,
                                                      (tuple, list)):
                outputs = [outputs]

        return ConcreteProgram(
            inputs=static_inputs,
            outputs=outputs,
            parameters=params_and_buffers,
            func=dyfunc,
            main_program=main_prog,
            startup_program=startup_prog)
297 298


299 300 301 302
class ProgramCache(object):
    """
    Insertion-ordered cache that maps a `FunctionSpec` to the pair
    `(ConcreteProgram, partial program)` built for it.
    """

    def __init__(self):
        self._caches = collections.OrderedDict()

    def _build_once(self, func_spec):
        """
        Builds the concrete program for `func_spec` together with the
        partial program derived from it.
        """
        concrete_program = ConcreteProgram.from_func_spec(func_spec)
        partial_program = partial_program_from(concrete_program)
        return concrete_program, partial_program

    def __getitem__(self, item):
        """
        Returns the cached entry for `item`, building and caching it first
        when it is missing.

        Raises:
            ValueError: if `item` is not a `FunctionSpec`.
        """
        if not isinstance(item, FunctionSpec):
            raise ValueError(
                'type(item) should be FunctionSpec, but received %s' %
                type(item))
        if item in self._caches:
            return self._caches[item]

        built = self._build_once(item)
        self._caches[item] = built
        return built

    def get_program(self, item):
        """
        Returns the cached entry for `item` without ever building it.

        Raises:
            ValueError: if `item` is not a `FunctionSpec`.
            RuntimeError: if nothing is cached for `item`.
        """
        if not isinstance(item, FunctionSpec):
            raise ValueError(
                "Input item's type should be FunctionSpec, but received %s" %
                type(item))
        if item in self._caches:
            return self._caches[item]
        raise RuntimeError(
            "Failed to find program for input item, please decorate input function by `@declarative`."
        )

    def last(self):
        """
        Returns `(key, value)` of the most recently inserted entry.
        """
        assert len(
            self._caches) >= 1, "No valid cached program in ProgramCache."
        newest_key = next(reversed(self._caches.keys()))
        return newest_key, self._caches[newest_key]

337

338 339
def synchronized(func):
    """
    Decorator that serializes calls to `func` with a dedicated lock.

    A new `threading.Lock` is attached to `func` as `func.__lock__`, and the
    returned wrapper holds that lock for the duration of each call. The lock
    is not reentrant, so `func` must not call its own wrapper recursively.

    Args:
        func(callable): The function to protect.

    Returns:
        callable: A wrapper with the same name, docstring and call signature
        as `func`.
    """
    func.__lock__ = threading.Lock()

    # functools.wraps keeps __name__/__doc__ of the wrapped function so the
    # decorated callable stays recognizable in tracebacks and introspection.
    @functools.wraps(func)
    def lock_func(*args, **kwargs):
        with func.__lock__:
            return func(*args, **kwargs)

    return lock_func
346 347


348
class ProgramTranslator(object):
    """
    Class to translate dygraph function into static graph function. The object
    of this class is a singleton.

    Args:
        None.

    Returns:
        ProgramTranslator: the singleton object.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            # Two methods get same object because ProgramTranslator is a singleton
            fluid.dygraph.ProgramTranslator()
            fluid.dygraph.ProgramTranslator.get_instance()

    """

    _singleton_lock = threading.Lock()
    _instance = None

    @synchronized
    def __new__(cls, *args, **kwargs):
        # Calls are serialized by `synchronized`, so the instance is created
        # at most once; every later call returns the same object.
        if cls._instance is None:
            cls._instance = object.__new__(cls, *args, **kwargs)
            # Lets __init__ tell the first initialization from later calls.
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls):
        """
        Returns the singleton instance, creating it on first use.
        """
        if cls._instance is None:
            with cls._singleton_lock:
                cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """
        Re-initializes the existing singleton (if any): the program cache is
        replaced by an empty one and `enable_declarative` is set back to True.
        """
        if cls._instance is not None:
            cls._instance._initialized = False
            cls._instance.__init__()

    def __init__(self):
        # To make sure that calls __init__ only once.
        if self._initialized:
            return
        self._initialized = True
        self._program_cache = ProgramCache()
        self.enable_declarative = True

    def enable(self, enable_declarative):
        """
        Enable or disable the converting from imperative to declarative by
        ProgramTranslator globally.

        Args:
            enable_declarative (bool): True or False to enable or disable declarative.

        Returns:
            None.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                @fluid.dygraph.jit.declarative
                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()
                prog_trans.enable(False)

                x = np.ones([1, 2])
                # The declarative is disabled so the func is run in dygraph
                with fluid.dygraph.guard():
                    print(func(x).numpy()) # [[2. 2.]]

        """
        check_type(enable_declarative, "enable_declarative", bool,
                   "ProgramTranslator.enable")
        self.enable_declarative = enable_declarative

    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Returns the output dygraph VarBase for dygraph function. The dygraph
        function will be translated into static graph function so the
        underlying numerical result will be calculated by declarative mode.

        When declarative mode is disabled, a warning is issued and
        `dygraph_func(*args, **kwargs)` is returned instead.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            VarBase or tuple of VarBase: the dygraph VarBase containing digital
                result.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                with fluid.dygraph.guard():
                    x = np.ones([1, 2])
                    x_v = prog_trans.get_output(func, x)
                    print(x_v.numpy()) # [[0. 0.]]

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
        if not self.enable_declarative:
            warnings.warn(
                "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        function_spec = FunctionSpec(dygraph_func, args, kwargs)
        _, partial_program_layer = self._program_cache[function_spec]

        if args and isinstance(args[0], layers.Layer):
            # Synchronize self.training attribute.
            partial_program_layer.training = args[0].training
            # The layer instance itself is not an input of the program.
            args = args[1:]

        # NOTE: only the positional arguments are passed to the program layer.
        return partial_program_layer(args)

    def get_func(self, dygraph_func):
        """
        Returns a callable function which converts imperative dygraph APIs of
        the input dygraph_func into declarative net-building APIs, which means
        it doesn't return immediate digital result as get_output does.
        Users should handle Program and Executor by themselves.

        When declarative mode is disabled, a warning is issued and
        `dygraph_func` itself is returned.

        Args:
            dygraph_func (callable): the dygraph function.

        Returns:
            callable: converting imperative dygraph APIs into declarative
            net-building APIs.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                static_func = prog_trans.get_func(func)
                print(callable(static_func)) # True

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
        if not self.enable_declarative:
            warnings.warn(
                "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will "
                "just return dygraph output.")
            return dygraph_func

        static_func = convert_to_static(dygraph_func)
        return static_func

    def get_program(self, dygraph_func, *args, **kwargs):
        """
        Returns the translated static program and input/output variables from
        dygraph function. The users can use the program to run by executor.

        When declarative mode is disabled, a warning is issued and
        `dygraph_func(*args, **kwargs)` is returned instead of the tuple
        described below.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            tuple of (main_program, startup_program, inputs, outputs) whose
            types are (Program, Program, list of Variable, list of Variable).
            main_program: the converted main program.
            startup_program: the converted startup program.
            inputs: list of input Variables which need to be fed.
            outputs: list of output Variables which users can fetch.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                x = np.ones([1, 2])
                main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x)
                print([i.name for i in inputs])
                # ['feed_0'] the feed input variable name representing x
                print([o.name for o in outputs])
                # ['_generated_var_4'] the fetch output variable name representing x_v

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
        if not self.enable_declarative:
            # The trailing space keeps the two sentences apart once the
            # adjacent string literals are concatenated.
            warnings.warn(
                "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        func_spec = FunctionSpec(dygraph_func, args, kwargs)
        concrete_program, _ = self._program_cache[func_spec]
        # Note: concrete_program hold all input/output infos include non-Variable
        input_vars = [
            var for var in concrete_program.inputs
            if isinstance(var, framework.Variable)
        ]
        output_vars = [
            var for var in concrete_program.outputs
            if isinstance(var, framework.Variable)
        ]

        return concrete_program.main_program, \
               concrete_program.startup_program, \
               input_vars, \
               output_vars

    def get_code(self, dygraph_func):
        """
        Returns the translated static function string code from dygraph function.

        Args:
            dygraph_func (callable): the dygraph function.

        Returns:
            str: the string code of translated static function.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                code = prog_trans.get_code(func)
                print(type(code)) # <class 'str'>

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_code"
        # Note: In Python2, inspecting a decorated function directly raises
        # OSError; function.__wrapped__ holds the actual function. This is the
        # same unwrapping that FunctionCache._convert applies.
        dygraph_func = getattr(dygraph_func, '__wrapped__', dygraph_func)

        # Gets AST from dygraph function
        raw_code = inspect.getsource(dygraph_func)
        code = textwrap.dedent(raw_code)
        root = gast.parse(code)

        # Transform AST
        dygraph_to_static = DygraphToStaticAst()
        root_wrapper = dygraph_to_static.get_static_ast(root)

        # Get source_code
        source_code = ast_to_source_code(root_wrapper.node)
        return source_code

    def get_program_cache(self):
        """
        Returns the ProgramCache instance. This method is used by PaddlePaddle
        developers to manage program cache in ProgramTranslator. Normal users
        don't have to call this method.

        Returns:
            ProgramCache: ProgramCache instance of ProgramTranslator.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                prog_trans = fluid.dygraph.ProgramTranslator()
                prog_cache = prog_trans.get_program_cache()

        """
        return self._program_cache