# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import functools
import inspect
import logging
import textwrap
import threading

import gast
import numpy as np

from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid import unique_name
from paddle.fluid.data_feeder import check_type
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import param_guard
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code

# Public API of this module.
__all__ = ['ProgramTranslator', 'convert_function_with_cache']

# Module logger; messages are emitted under the "fluid" logger name.
logger = logging.getLogger("fluid")


class FunctionCache(object):
    """
    Caches the transformed functions to avoid redundant conversions of the same function.

    Both internal dicts are keyed by the original dygraph function object.
    """

    def __init__(self):
        # dygraph function -> converted static function
        self._dycode_to_static_func = dict()
        # dygraph function -> transformer returned by convert_to_static.
        # NOTE: despite its name, this dict is keyed by the dygraph function,
        # not by the static function.
        self._static_func_to_transformer = dict()

    def get_or_cache_func(self, func):
        """
        Returns the static function converted from the dygraph ``func``.

        On the first call for ``func`` the conversion is done with
        ``convert_to_static`` and both the static function and its
        transformer are cached; later calls return the cached static function.
        """
        static_func = self._dycode_to_static_func.get(func, None)

        if static_func is None:
            static_func, dygraph_to_static_transformer = convert_to_static(func)
            self._dycode_to_static_func[func] = static_func
            self._static_func_to_transformer[
                func] = dygraph_to_static_transformer

        return static_func

    def get_transformer(self, func):
        """
        Returns the transformer cached for the dygraph function ``func``, or
        None if ``func`` has not been converted yet.
        """
        return self._static_func_to_transformer.get(func, None)

    def _get_dedent_code_string(self, func):
        """Returns the source code of ``func`` with common leading indentation removed."""
        raw_code = inspect.getsource(func)
        dedent_code = textwrap.dedent(raw_code)
        return dedent_code

    def exist(self, func):
        """Returns True if a static function is already cached for ``func``."""
        return self._dycode_to_static_func.get(func, None) is not None


# Guards _FUNCTION_CACHE so concurrent conversions don't race on its dicts.
_CACHE_LOCK = threading.Lock()
# Process-wide cache of dygraph function -> converted static function.
_FUNCTION_CACHE = FunctionCache()


def convert_function_with_cache(dygraph_func):
    """
    Transforms function of dygraph into static function using the cache mechanism.

    Args:
        dygraph_func (callable): the dygraph function to convert.

    Returns:
        callable: the converted static function. Repeated calls with the same
        ``dygraph_func`` return the cached result instead of converting again.
    """
    with _CACHE_LOCK:
        return _FUNCTION_CACHE.get_or_cache_func(dygraph_func)


class FunctionSpec(object):
    """
    Describes a call of a dygraph function: the function plus the positional
    and keyword arguments it was called with. Instances are used as keys of
    ProgramCache.

    NOTE: equality and hashing only consider the function and, for Layer
    methods, the Layer instance (see ``__key``); the remaining arguments are
    not part of the key.
    """

    def __init__(self, func, args, kwargs):
        self._dyfunc = func
        self._args = args
        self._kwargs = kwargs

    def is_method(self):
        """
        Returns True if the first positional argument is a ``layers.Layer``
        instance, i.e. ``dyfunc`` is called as a method of that Layer.
        """
        return bool(self._args) and isinstance(self._args[0], layers.Layer)

    def parameters(self, include_sublayer=True):
        """
        Returns the parameters of the Layer instance when ``is_method()`` is
        True, otherwise an empty dict.

        Args:
            include_sublayer (bool): if True, returns ``layer.parameters()``;
                otherwise returns the Layer's own ``_parameters``.
        """
        params = {}
        if self.is_method():
            if include_sublayer:
                params = self._args[0].parameters()
            else:
                params = self._args[0]._parameters
        return params

    @switch_to_static_graph
    def to_static_inputs(self, main_program):
        """
        Creates static-graph input variables in the global block of
        ``main_program``, one per positional argument.

        numpy arrays get a new data variable named by
        ``unique_name.generate('feed')`` with the same shape and dtype;
        ``core.VarBase`` arguments get a variable with the same name, shape,
        dtype and stop_gradient; any other argument is passed through as-is.

        Returns:
            list: the created (or passed-through) inputs, in argument order.
        """
        inputs = []
        block = main_program.global_block()
        # NOTE: only positional arguments are converted; self._kwargs is unused.
        for input_var in self.args:
            if isinstance(input_var, np.ndarray):
                feed_layer = block.create_var(
                    name=unique_name.generate('feed'),
                    shape=list(input_var.shape),
                    dtype=input_var.dtype,
                    is_data=True,
                    need_check_feed=False)
            elif isinstance(input_var, core.VarBase):
                feed_layer = block.create_var(
                    name=input_var.name,
                    shape=list(input_var.shape),
                    dtype=input_var.dtype,
                    stop_gradient=input_var.stop_gradient,
                    need_check_feed=False)
            else:
                feed_layer = input_var

            inputs.append(feed_layer)
        return inputs

    @property
    def dyfunc(self):
        """The original dygraph function."""
        return self._dyfunc

    @property
    def args(self):
        """The positional arguments of the call."""
        return self._args

    def __key(self):
        # Note: if dygraph function is a method of class,
        # consider instance info as hash key.
        if self.is_method():
            return self._dyfunc, self._args[0]
        else:
            return self._dyfunc

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        # Compare against the *other* spec's key. The previous implementation
        # compared self with itself, so any two specs were considered equal.
        if not isinstance(other, FunctionSpec):
            return NotImplemented
        return self.__key() == other.__key()

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class ConcreteProgram(object):
    """
    Holds the result of translating a dygraph function into static programs:
    the input and output variables, the parameters, and the main/startup
    programs.
    """

    def __init__(self,
                 inputs,
                 outputs,
                 parameters,
                 func,
                 main_program,
                 start_up=None):
        self.inputs = inputs
        self.outputs = outputs
        self.main_program = main_program
        self.startup_program = start_up
        self.parameters = parameters
        # NOTE(review): from_func_spec passes the dygraph function here, not
        # the FunctionSpec, despite the attribute name.
        self.func_spec = func

    @staticmethod
    @switch_to_static_graph
    def from_func_spec(func_spec):
        """
        Builds the main_program with specialized inputs and returns outputs
        of program as fetch_list.

        Args:
            func_spec (FunctionSpec): the dygraph function and its call arguments.

        Returns:
            ConcreteProgram: holds new main/startup programs, the inputs made
            by ``func_spec.to_static_inputs``, the outputs of the converted
            function (always a list or tuple; ``None`` becomes an empty list)
            and the parameters returned by ``func_spec.parameters()``.
        """
        # Transforms dygraph function into static function and caches it.
        dygraph_function = func_spec.dyfunc
        static_func = convert_function_with_cache(dygraph_function)

        main_program, start_up = framework.Program(), framework.Program()
        with framework.program_guard(main_program, start_up):
            # 1. Adds `fluid.data` layers for input if needed
            inputs = func_spec.to_static_inputs(main_program)

            # 2. Gets all ParamBases in the function
            all_parameters = func_spec.parameters()

            # 3. Builds program only once and returns the output Variables.
            with param_guard(func_spec.parameters(False)):
                outputs = static_func(*inputs)
            if not isinstance(outputs, (tuple, list)):
                # Check for None explicitly rather than relying on the
                # truthiness of a static-graph Variable.
                outputs = [outputs] if outputs is not None else []

        return ConcreteProgram(
            inputs=inputs,
            outputs=outputs,
            parameters=all_parameters,
            func=dygraph_function,
            main_program=main_program,
            start_up=start_up)


class ProgramCache(object):
    """
    Wrapper class for the program functions defined by dygraph function.

    Maps each FunctionSpec to the (ConcreteProgram, partial program) pair
    built for it. The pair is built on the first lookup and reused afterwards.
    """

    def __init__(self):
        # FunctionSpec -> (ConcreteProgram, partial_program_from(...) result)
        self._caches = {}

    def _build_once(self, func_spec):
        """Translates ``func_spec`` and wraps the result with partial_program_from."""
        concrete_program = ConcreteProgram.from_func_spec(func_spec)
        return concrete_program, partial_program_from(concrete_program)

    def __getitem__(self, item):
        """
        Returns the cached (ConcreteProgram, partial program) pair for
        ``item``, building it first if it is not cached yet.

        Raises:
            ValueError: if ``item`` is not a FunctionSpec.
        """
        if not isinstance(item, FunctionSpec):
            raise ValueError(
                'type(item) should be FunctionSpec, but received %s' %
                type(item))
        if item not in self._caches:
            self._caches[item] = self._build_once(item)
        return self._caches[item]


def synchronized(func):
    """
    Decorator that serializes calls to ``func`` with a lock created once per
    decorated function and stored as ``func.__lock__``.

    The wrapper keeps ``func``'s name, docstring and other metadata.
    """
    func.__lock__ = threading.Lock()

    @functools.wraps(func)
    def lock_func(*args, **kwargs):
        with func.__lock__:
            return func(*args, **kwargs)

    return lock_func


class ProgramTranslator(object):
    """
    Class to translate dygraph function into static graph function. The object
    of this class is a singleton.

    Args:
        None.

    Returns:
        ProgramTranslator: the singleton object.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            # Two methods get same object because ProgramTranslator is a singleton
            fluid.dygraph.ProgramTranslator()
            fluid.dygraph.ProgramTranslator.get_instance()

    """

    _singleton_lock = threading.Lock()
    _instance = None

    @synchronized
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # Do not forward the constructor arguments: on Python 3,
            # object.__new__ raises TypeError when given extra arguments
            # while __new__ is overridden.
            cls._instance = object.__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls):
        """Returns the singleton ProgramTranslator, creating it if needed."""
        if cls._instance is None:
            with cls._singleton_lock:
                # Re-check under the lock in case another thread created the
                # instance after the unlocked check.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """
        Re-initializes the singleton (if it exists): its ProgramCache is
        replaced by an empty one and enable_declarative is set back to True.
        """
        if cls._instance is not None:
            cls._instance._initialized = False
            cls._instance.__init__()

    def __init__(self):
        # To make sure that calls __init__ only once.
        if self._initialized:
            return
        self._initialized = True
        self._program_cache = ProgramCache()
        self.enable_declarative = True

    def enable(self, enable_declarative):
        """
        Enable or disable the converting from imperative to declarative by
        ProgramTranslator globally.

        Args:
            enable_declarative (bool): True or False to enable or disable declarative.

        Returns:
            None.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                @fluid.dygraph.jit.declarative
                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()
                prog_trans.enable(False)

                x = np.ones([1, 2])
                # The declarative is disabled so the func is run in dygraph
                with fluid.dygraph.guard():
                    print(func(x).numpy()) # [[0. 0.]]

        """
        check_type(enable_declarative, "enable_declarative", bool,
                   "ProgramTranslator.enable")
        self.enable_declarative = enable_declarative

    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Returns the output dygraph VarBase for dygraph function. The dygraph
        function will be translated into static graph function so the
        underlying numerical result will be calculated by declarative mode.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            VarBase or tuple of VarBase: the dygraph VarBase containing the
                numerical result.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                x = np.ones([1, 2])
                x_v = prog_trans.get_output(func, x)
                print(x_v.numpy()) # [[0. 0.]]

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
        if not self.enable_declarative:
            logger.info(
                "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        function_spec = FunctionSpec(dygraph_func, args, kwargs)
        _, partial_program_layer = self._program_cache[function_spec]

        # For a Layer method the Layer instance (first positional argument)
        # is not a program input, so it is dropped before running.
        if function_spec.is_method():
            args = args[1:]

        return partial_program_layer(args)

    def get_func(self, dygraph_func):
        """
        Returns a callable function which converts imperative dygraph APIs of
        the input dygraph_func into declarative net-building APIs, which means
        it doesn't return an immediate numerical result as get_output does.
        Users should handle Program and Executor by themselves.

        Args:
            dygraph_func (callable): the dygraph function.

        Returns:
            callable: converting imperative dygraph APIs into declarative
            net-building APIs.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                static_func = prog_trans.get_func(func)
                print(callable(static_func)) # True

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
        if not self.enable_declarative:
            logger.info(
                "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will "
                "just return dygraph output.")
            return dygraph_func

        static_func = convert_function_with_cache(dygraph_func)
        return static_func

    def get_program(self, dygraph_func, *args, **kwargs):
        """
        Returns the translated static program and input/output variables from
        dygraph function. The users can use the program to run by executor.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            tuple of (main_program, startup_program, inputs, outputs) whose
            types are (Program, Program, list of Variable, list of Variable).
            main_program: the converted main program.
            startup_program: the converted startup program.
            inputs: list of input Variables which need to be fed.
            outputs: list of output Variables which users can fetch.

            NOTE: when declarative is disabled via ``enable(False)``, the
            dygraph output of ``dygraph_func`` is returned instead.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                x = np.ones([1, 2])
                main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x)
                print([i.name for i in inputs])
                # e.g. ['feed_0']: the feed input variable names representing x
                print([o.name for o in outputs])
                # e.g. ['_generated_var_4']: the fetch output variable names representing x_v

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
        if not self.enable_declarative:
            logger.info(
                "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        func_spec = FunctionSpec(dygraph_func, args, kwargs)
        concrete_program, _ = self._program_cache[func_spec]
        return concrete_program.main_program, \
               concrete_program.startup_program, \
               concrete_program.inputs, \
               concrete_program.outputs

    def get_code(self, dygraph_func):
        """
        Returns the translated static function string code from dygraph function.

        Args:
            dygraph_func (callable): the dygraph function.

        Returns:
            str: the string code of translated static function.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                code = prog_trans.get_code(func)
                print(type(code)) # <class 'str'>

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_code"
        # Gets AST from dygraph function
        raw_code = inspect.getsource(dygraph_func)
        code = textwrap.dedent(raw_code)
        root = gast.parse(code)

        # Transform AST
        dygraph_to_static = DygraphToStaticAst()
        root_wrapper = dygraph_to_static.get_static_ast(root)

        # Get source_code
        source_code = ast_to_source_code(root_wrapper.node)
        return source_code

    def get_program_cache(self):
        """
        Returns the ProgramCache instance. This method is used by PaddlePaddle
        developers to manage program cache in ProgramTranslator. Normal users
        don't have to call this method.

        Returns:
            ProgramCache: ProgramCache instance of ProgramTranslator.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                prog_trans = fluid.dygraph.ProgramTranslator()
                prog_cache = prog_trans.get_program_cache()

        """
        return self._program_cache