program_translator.py 23.0 KB
Newer Older
1
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
16
import gast
17
import inspect
18
import logging
19 20
import textwrap
import threading
21
import collections
22
import numpy as np
23
from paddle.fluid import core, scope_guard
24
from paddle.fluid import framework
25
from paddle.fluid import executor
26 27 28
from paddle.fluid import unique_name
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
29 30 31
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
32
from paddle.fluid.dygraph.base import param_guard
33
from paddle.fluid.data_feeder import check_type
34
from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
35

36
# Public API of this module: the names exported by a star-import.
__all__ = ['ProgramTranslator', 'convert_function_with_cache']
37

38 39
# Module-level logger, obtained under the name "fluid".
logger = logging.getLogger("fluid")

40 41 42 43 44 45 46

class FunctionCache(object):
    """
    Remembers the static function produced for each dygraph function so that
    a function which was already converted is not converted a second time.
    """

    def __init__(self):
        # Both tables are keyed by the original dygraph function object.
        self._dycode_to_static_func = {}
        self._static_func_to_transformer = {}

    def get_or_cache_func(self, func):
        """
        Returns the static function cached for `func`. When nothing is cached
        yet, `func` is converted with `convert_to_static` and both the static
        function and its transformer are stored before returning.
        """
        cached_func = self._dycode_to_static_func.get(func)
        if cached_func is not None:
            return cached_func

        converted_func, transformer = convert_to_static(func)
        self._dycode_to_static_func[func] = converted_func
        self._static_func_to_transformer[func] = transformer
        return converted_func

    def get_transformer(self, func):
        """
        Returns the transformer stored for `func`, or None when `func` has not
        been converted through this cache.
        """
        return self._static_func_to_transformer.get(func)

    def _get_dedent_code_string(self, func):
        """
        Returns the source code of `func` with its common leading indentation
        removed.
        """
        return textwrap.dedent(inspect.getsource(func))

    def exist(self, func):
        """
        Tells whether a static function is already cached for `func`.
        """
        cached_func = self._dycode_to_static_func.get(func)
        return cached_func is not None
72 73


74 75 76 77 78 79
# Lock held by convert_function_with_cache around every access to
# _FUNCTION_CACHE.
_CACHE_LOCK = threading.Lock()
# Module-wide cache of converted functions used by convert_function_with_cache.
_FUNCTION_CACHE = FunctionCache()


def convert_function_with_cache(dygraph_func):
    """
    Converts a dygraph function into its static counterpart through the
    module-level function cache.

    The cache lock is held for the whole lookup-or-convert step, so the cache
    is only touched by one thread at a time.
    """
    with _CACHE_LOCK:
        return _FUNCTION_CACHE.get_or_cache_func(dygraph_func)


87 88 89 90 91
class FunctionSpec(object):
    """
    Describes one call of a dygraph function: the function itself together
    with the positional and keyword arguments it was called with.

    Instances are hashable and comparable. The identity of a spec is the
    dygraph function, plus the owning `Layer` instance when the function is
    called as a method; the remaining argument values are not part of it.
    """

    def __init__(self, func, args, kwargs):
        self._dyfunc = func
        self._args = args
        # Stored for completeness; none of the methods below read it.
        self._kwargs = kwargs

    def is_method(self):
        """
        Returns True when the first positional argument is a `Layer`, i.e.
        the function is being called as a method of a dygraph layer.
        """
        return bool(self._args) and isinstance(self._args[0], layers.Layer)

    def parameters(self, include_sublayer=True):
        """
        Returns the parameters of the owning layer as a name -> parameter
        mapping, or an empty OrderedDict when the function is not a method.

        Args:
            include_sublayer (bool): if True, collect everything returned by
                the layer's `parameters()`; otherwise return the layer's own
                `_parameters` mapping.
        """
        params = collections.OrderedDict()
        if self.is_method():
            if include_sublayer:
                params = self._args[0].parameters()
                names = [p.name for p in params]
                params = collections.OrderedDict(zip(names, params))
            else:
                params = self._args[0]._parameters
        return params

    @switch_to_static_graph
    def to_static_inputs(self, main_program):
        """
        Builds the static-graph inputs matching the positional arguments.

        A numpy array becomes a new data variable with a generated `feed`
        name, a `core.VarBase` becomes a variable reusing its name, and any
        other argument is passed through unchanged.

        Args:
            main_program (Program): program whose global block receives the
                created variables.

        Returns:
            list: one entry per positional argument, in the same order.
        """
        inputs = []
        block = main_program.global_block()
        for input_var in self.args:
            if isinstance(input_var, np.ndarray):
                feed_layer = block.create_var(
                    name=unique_name.generate('feed'),
                    shape=list(input_var.shape),
                    dtype=input_var.dtype,
                    is_data=True,
                    need_check_feed=False)
            elif isinstance(input_var, core.VarBase):
                feed_layer = block.create_var(
                    name=input_var.name,
                    shape=list(input_var.shape),
                    dtype=input_var.dtype,
                    stop_gradient=input_var.stop_gradient,
                    need_check_feed=False)
            else:
                feed_layer = input_var

            inputs.append(feed_layer)
        return inputs

    @property
    def dyfunc(self):
        return self._dyfunc

    @property
    def args(self):
        return self._args

    def __key(self):
        # Note: if dygraph function is a method of class,
        # consider instance info as hash key.
        if self.is_method():
            return self._dyfunc, self._args[0]
        else:
            return self._dyfunc

    def __hash__(self):
        return hash(self.__key())

    def __eq__(self, other):
        # Compare against `other`. Comparing the key with itself would make
        # every pair of specs equal, so two unrelated specs whose hashes
        # collide would share one cache entry.
        if not isinstance(other, FunctionSpec):
            return NotImplemented
        return self.__key() == other.__key()

    def __ne__(self, other):
        # Python 2 does not derive `!=` from `__eq__`; keep both consistent.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result


class ConcreteProgram(object):
    """
    Plain record of a static program built from a dygraph function: its input
    and output variables, its parameters, the main and startup programs, and
    the dygraph function it originates from.
    """

    def __init__(self,
                 inputs,
                 outputs,
                 parameters,
                 func,
                 main_program,
                 startup_program=None):
        self.main_program = main_program
        self.startup_program = startup_program
        self.inputs = inputs
        self.outputs = outputs
        self.parameters = parameters
        # NOTE(review): despite the attribute name, `from_func_spec` passes the
        # dygraph function here, not the FunctionSpec.
        self.func_spec = func

    @staticmethod
    @switch_to_static_graph
    def from_func_spec(func_spec):
        """
        Traces the function described by `func_spec` into a fresh pair of
        programs and wraps the result into a ConcreteProgram.

        The outputs of the traced function are kept as the program outputs;
        a single non-list return value is wrapped into a one-element list.
        """
        # Convert the dygraph function (or reuse the cached conversion).
        dygraph_function = func_spec.dyfunc
        static_func = convert_function_with_cache(dygraph_function)

        main_program = framework.Program()
        startup_program = framework.Program()
        # The random seeds of the current default programs are copied into
        # the new ones, because some ops depend on the seed (for example
        # `fluid.layers.dropout`) and it may have been set inside
        # `fluid.dygraph.guard`.
        default_main = framework.default_main_program()
        main_program.random_seed = default_main.random_seed
        default_startup = framework.default_startup_program()
        startup_program.random_seed = default_startup.random_seed

        with framework.program_guard(main_program, startup_program):
            # Step 1: create the input variables for the call arguments.
            inputs = func_spec.to_static_inputs(main_program)

            # Step 2: collect every parameter, sublayers included.
            all_parameters = list(func_spec.parameters().values())

            # Step 3: run the static function once to populate the program.
            with param_guard(func_spec.parameters(False)):
                outputs = static_func(*inputs)
            if not isinstance(outputs, (tuple, list)):
                # NOTE(review): this is a truthiness test, so any falsy return
                # value (not only None) yields an empty output list.
                if outputs:
                    outputs = [outputs]
                else:
                    outputs = []

        return ConcreteProgram(
            inputs=inputs,
            outputs=outputs,
            parameters=all_parameters,
            func=dygraph_function,
            main_program=main_program,
            startup_program=startup_program)
209 210


211 212 213 214
class ProgramCache(object):
    """
    Cache of the programs built from dygraph functions. Each FunctionSpec key
    maps to a `(ConcreteProgram, partial program)` pair.
    """

    def __init__(self):
        # Insertion order is kept so that `last()` can find the newest entry.
        self._caches = collections.OrderedDict()

    def _build_once(self, func_spec):
        """
        Builds the ConcreteProgram for `func_spec` and the partial program
        derived from it, and returns both as a tuple.
        """
        program = ConcreteProgram.from_func_spec(func_spec)
        partial_program = partial_program_from(program)
        return program, partial_program

    def __getitem__(self, item):
        """
        Returns the cached pair for `item`, building and storing it on the
        first request.

        Raises:
            ValueError: if `item` is not a FunctionSpec.
        """
        if not isinstance(item, FunctionSpec):
            raise ValueError(
                'type(item) should be FunctionSpec, but received %s' %
                type(item))
        if item in self._caches:
            return self._caches[item]

        built = self._build_once(item)
        self._caches[item] = built
        return built

    def last(self):
        """
        Returns `(key, value)` of the most recently inserted cache entry.
        Asserts that the cache is not empty.
        """
        assert len(
            self._caches) >= 1, "No valid cached program in ProgramCache."
        newest_key = next(reversed(self._caches))
        return newest_key, self._caches[newest_key]

238

239 240
def synchronized(func):
    """
    Decorator that serializes calls to `func`.

    A dedicated `threading.Lock` is created for the decorated function and
    stored on it as `func.__lock__`; every call through the returned wrapper
    runs while holding that lock.

    Args:
        func (callable): the function to protect.

    Returns:
        callable: a wrapper with the same name and docstring as `func`.
    """
    # Imported locally so this function does not depend on a module-level
    # import that the rest of the file does not need.
    import functools

    func.__lock__ = threading.Lock()

    # `wraps` keeps __name__/__doc__ of the wrapped function, which would
    # otherwise be reported as `lock_func`.
    @functools.wraps(func)
    def lock_func(*args, **kwargs):
        with func.__lock__:
            return func(*args, **kwargs)

    return lock_func
247 248


249
class ProgramTranslator(object):
    """
    Class to translate dygraph function into static graph function. The object
    of this class is a singleton.

    Args:
        None.

    Returns:
        ProgramTranslator: the singleton object.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            # Two methods get same object because ProgramTranslator is a singleton
            fluid.dygraph.ProgramTranslator()
            fluid.dygraph.ProgramTranslator.get_instance()

    """

    # Held by `get_instance` while it creates the instance.
    _singleton_lock = threading.Lock()
    # The single shared instance; created on first construction.
    _instance = None

    @synchronized
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # `object.__new__` takes only the class. Forwarding constructor
            # arguments to it raises TypeError on Python 3 when `__new__` is
            # overridden, so they are intentionally not passed on.
            cls._instance = object.__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls):
        """
        Returns the singleton instance, creating it if it does not exist yet.
        """
        if cls._instance is None:
            with cls._singleton_lock:
                cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """
        Re-runs `__init__` on the existing instance, which replaces its
        program cache with an empty one and re-enables declarative mode.
        Does nothing when no instance has been created.
        """
        if cls._instance is not None:
            cls._instance._initialized = False
            cls._instance.__init__()

    def __init__(self):
        # `__init__` runs on every `ProgramTranslator()` call because the
        # same object is returned each time; only the first run (or the run
        # after `reset`) may initialize the state.
        if self._initialized:
            return
        self._initialized = True
        self._program_cache = ProgramCache()
        self.enable_declarative = True

    def enable(self, enable_declarative):
        """
        Enable or disable the converting from imperative to declarative by
        ProgramTranslator globally.

        Args:
            enable_declarative (bool): True or False to enable or disable declarative.

        Returns:
            None.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                @fluid.dygraph.jit.declarative
                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()
                prog_trans.enable(False)

                x = np.ones([1, 2])
                # The declarative is disabled so the func is run in dygraph
                with fluid.dygraph.guard():
                    print(func(x).numpy()) # [[0. 0.]]

        """
        check_type(enable_declarative, "enable_declarative", bool,
                   "ProgramTranslator.enable")
        self.enable_declarative = enable_declarative

    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Returns the output dygraph VarBase for dygraph function. The dygraph
        function will be translated into static graph function so the
        underlying numerical result will be calculated by declarative mode.

        When declarative mode is disabled, the dygraph function is called
        directly and its result is returned.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            VarBase or tuple of VarBase: the dygraph VarBase containing digital
                result.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                with fluid.dygraph.guard():
                    x = np.ones([1, 2])
                    x_v = prog_trans.get_output(func, x)
                    print(x_v.numpy()) # [[0. 0.]]

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_output"
        if not self.enable_declarative:
            logger.info(
                "The ProgramTranslator.get_output doesn't work when setting ProgramTranslator.enable = False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        function_spec = FunctionSpec(dygraph_func, args, kwargs)
        _, partial_program_layer = self._program_cache[function_spec]

        # When called as a method, the leading Layer instance is not an input
        # of the program, so it is dropped before running.
        if args and isinstance(args[0], layers.Layer):
            args = args[1:]

        return partial_program_layer(args)

    def get_func(self, dygraph_func):
        """
        Returns a callable function which converts imperative dygraph APIs of
        the input dygraph_func into declarative net-building APIs, which means
        it doesn't return immediate digital result as get_output does.
        Users should handle Program and Executor by themselves.

        When declarative mode is disabled, `dygraph_func` itself is returned.

        Args:
            dygraph_func (callable): the dygraph function.

        Returns:
            callable: converting imperative dygraph APIs into declarative
            net-building APIs.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                static_func = prog_trans.get_func(func)
                print(callable(static_func)) # True

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_func"
        if not self.enable_declarative:
            logger.info(
                "The ProgramTranslator.get_func doesn't work when setting ProgramTranslator.enable=False. We will "
                "just return dygraph output.")
            return dygraph_func

        static_func = convert_function_with_cache(dygraph_func)
        return static_func

    def get_program(self, dygraph_func, *args, **kwargs):
        """
        Returns the translated static program and input/output variables from
        dygraph function. The users can use the program to run by executor.

        When declarative mode is disabled, the dygraph function is called
        directly and its result is returned instead of the tuple below.

        Args:
            dygraph_func (callable): the dygraph function.
            *args, **kwargs : the input argument of dygraph_func.

        Returns:
            tuple of (main_program, startup_program, inputs, outputs) whose
            types are (Program, Program, list of Variable, list of Variable).
            main_program: the converted main program.
            startup_program: the converted startup program.
            inputs: list of input Variables which need to be fed.
            outputs: list of output Variables which users can fetch.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                x = np.ones([1, 2])
                main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x)
                print([i.name for i in inputs])
                # ['feed_0'] the feed input variable name representing x
                print([o.name for o in outputs])
                # ['_generated_var_4'] the fetch output variable name representing x_v

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_program"
        if not self.enable_declarative:
            # The two literals are concatenated, so the first one must end
            # with a space to keep the sentences apart.
            logger.info(
                "The ProgramTranslator.get_program doesn't work when setting ProgramTranslator.enable=False. "
                "We will just return dygraph output.")
            return dygraph_func(*args, **kwargs)

        func_spec = FunctionSpec(dygraph_func, args, kwargs)
        concrete_program, _ = self._program_cache[func_spec]
        # Note: concrete_program hold all input/output infos include non-Variable
        input_vars = [
            var for var in concrete_program.inputs
            if isinstance(var, framework.Variable)
        ]
        output_vars = [
            var for var in concrete_program.outputs
            if isinstance(var, framework.Variable)
        ]

        return concrete_program.main_program, \
               concrete_program.startup_program, \
               input_vars, \
               output_vars

    def get_code(self, dygraph_func):
        """
        Returns the translated static function string code from dygraph function.

        Args:
            dygraph_func (callable): the dygraph function.

        Returns:
            str: the string code of translated static function.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid
                import numpy as np

                def func(x):
                    x = fluid.dygraph.to_variable(x)
                    if fluid.layers.mean(x) > 0:
                        x_v = x - 1
                    else:
                        x_v = x + 1
                    return x_v

                prog_trans = fluid.dygraph.ProgramTranslator()

                code = prog_trans.get_code(func)
                print(type(code)) # <class 'str'>

        """
        assert callable(
            dygraph_func
        ), "Input dygraph_func is not a callable in ProgramTranslator.get_code"
        # Gets AST from dygraph function
        raw_code = inspect.getsource(dygraph_func)
        code = textwrap.dedent(raw_code)
        root = gast.parse(code)

        # Transform AST
        dygraph_to_static = DygraphToStaticAst()
        root_wrapper = dygraph_to_static.get_static_ast(root)

        # Get source_code
        source_code = ast_to_source_code(root_wrapper.node)
        return source_code

    @switch_to_static_graph
    def save_inference_model(self, dirname, feed=None, fetch=None):
        """
        Saves current model as the inference model. It will prune the main_program
        to build a new program especially for inference, and then save it and all
        related parameters to given `dirname` . The saved inference model can be
        loaded by :ref:`api_fluid_io_load_inference_model` or C++ inference APIs.

        The model that is saved is the most recently cached program of this
        translator, so a dygraph function must have been translated before
        calling this method.

        Args:
            dirname (str): the directory to save the inference model.
            feed (list[int], optional): the indices of the input variables of the
                dygraph functions which will be saved as input variables in
                inference model. If None, all input variables of the dygraph function
                would be the inputs of the saved inference model. Default None.
            fetch (list[int], optional): the indices of the returned variable of the
                dygraph functions which will be saved as output variables in
                inference model. If None, all output variables of the dygraph function
                would be the outputs of the saved inference model. Default None.

        Returns:
            None

        Examples:
            .. code-block:: python

                import numpy as np
                import paddle.fluid as fluid
                from paddle.fluid.dygraph import Linear
                from paddle.fluid.dygraph import declarative
                from paddle.fluid.dygraph import ProgramTranslator

                class SimpleNet(fluid.dygraph.Layer):
                    def __init__(self, in_size, out_size):
                        super(SimpleNet, self).__init__()
                        self._linear = Linear(in_size, out_size)

                    @declarative
                    def forward(self, x):
                        y = self._linear(x)
                        z = self._linear(y)
                        loss = fluid.layers.mean(z)
                        return z, loss

                with fluid.dygraph.guard(fluid.CPUPlace()):
                    net = SimpleNet(8, 8)
                    adam = fluid.optimizer.AdamOptimizer(learning_rate=0.1, parameter_list=net.parameters())
                    x = fluid.dygraph.to_variable(np.random.random((4, 8)).astype('float32'))
                    for i in range(10):
                        out, loss = net(x)
                        loss.backward()
                        adam.minimize(loss)
                        net.clear_gradients()
                # Save inference model.
                # Note that fetch=[0] means we set 'z' as the inference output.
                prog_trans = ProgramTranslator()
                prog_trans.save_inference_model("./dy2stat_infer_model", fetch=[0])

                # In this example, the inference model will be pruned based on output (z).
                # The pruned inference program is going to be saved in the folder
                # "./dy2stat_infer_model" and parameters are going to be saved in separate
                # files in the folder.
        """

        def get_feed_fetch(var_list, partial_vars, return_name=False):
            # Keep only real Variables, optionally pick a subset by index,
            # and optionally reduce the result to variable names.
            selected = [
                var for var in var_list if isinstance(var, framework.Variable)
            ]
            if partial_vars:
                selected = [selected[idx] for idx in partial_vars]
            if return_name:
                selected = [var.name for var in selected]

            return selected

        func_spec, (concrete_program,
                    partial_layer) = self._program_cache.last()
        # share paramBase data with parameter
        scope = core.Scope()
        for param_base in concrete_program.parameters:
            param_tensor = scope.var(param_base.name).get_tensor()
            src_tensor = param_base.value().get_tensor()
            param_tensor._share_data_with(src_tensor)

        feed_var_names = get_feed_fetch(concrete_program.inputs, feed, True)
        fetch_vars = get_feed_fetch(concrete_program.outputs, fetch)

        from paddle.fluid.io import save_inference_model
        with scope_guard(scope):
            save_inference_model(
                dirname=dirname,
                feeded_var_names=feed_var_names,
                target_vars=fetch_vars,
                executor=executor.Executor(framework._current_expected_place()),
                main_program=concrete_program.main_program.clone())

    def get_program_cache(self):
        """
        Returns the ProgramCache instance. This method is used by PaddlePaddle
        developers to manage program cache in ProgramTranslator. Normal users
        don't have to call this method.

        Returns:
            ProgramCache: ProgramCache instance of ProgramTranslator.

        Examples:
            .. code-block:: python

                import paddle.fluid as fluid

                prog_trans = fluid.dygraph.ProgramTranslator()
                prog_cache = prog_trans.get_program_cache()

        """
        return self._program_cache