#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import gast
import inspect
import numpy
import textwrap
import threading
import warnings

from paddle.fluid import framework
from paddle.fluid import core, executor
from paddle.fluid.dygraph import guard, to_variable
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_type

__all__ = ['ProgramTranslator', 'convert_function_with_cache']


class FunctionCache(object):
    """
    Caches the transformed functions to avoid redundant conversions of the same function.
    """

    def __init__(self):
        # Maps the dedented source code of a dygraph function to its static function.
        self._dycode_to_static_func = dict()
        # Maps a static function to the AST transformer that produced it.
        self._static_func_to_transformer = dict()

    def get_or_cache_func(self, func):
        code = self._get_dedent_code_string(func)
        static_func = self._dycode_to_static_func.get(code, None)

        if static_func is None:
            static_func, dygraph_to_static_transformer = convert_to_static(func)
            self._dycode_to_static_func[code] = static_func
            self._static_func_to_transformer[
                static_func] = dygraph_to_static_transformer

        return static_func

    def get_transformer(self, func):
        return self._static_func_to_transformer.get(func, None)

    def _get_dedent_code_string(self, func):
        raw_code = inspect.getsource(func)
        dedent_code = textwrap.dedent(raw_code)
        return dedent_code

    def exist(self, func):
        return self._dycode_to_static_func.get(
            self._get_dedent_code_string(func), None) is not None


_CACHE_LOCK = threading.Lock()
_FUNCTION_CACHE = FunctionCache()


def convert_function_with_cache(dygraph_func):
    """
    Transforms a dygraph function into a static function, using the cache
    mechanism to avoid converting the same function repeatedly.
    """
    with _CACHE_LOCK:
        static_func = _FUNCTION_CACHE.get_or_cache_func(dygraph_func)
        return static_func
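
# A minimal usage sketch of `convert_function_with_cache`; `dygraph_fn` is a
# hypothetical example function, not part of this module. Converting the same
# function twice returns the identical cached static function:
#
#     def dygraph_fn(x):
#         return fluid.layers.relu(x)
#
#     static_fn = convert_function_with_cache(dygraph_fn)
#     assert static_fn is convert_function_with_cache(dygraph_fn)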


def synchronized(func):
    """
    Decorator that serializes calls to `func` with a per-function lock.
    """
    func.__lock__ = threading.Lock()

    def lock_func(*args, **kwargs):
        with func.__lock__:
            return func(*args, **kwargs)

    return lock_func


class ProgramCache(object):
    """
    Wrapper class for the program functions defined by dygraph function.
    """

    def __init__(self):
        self._inputs = []
        self._outputs = []
        # Always set program to default_main_program. Because once `__call__` is called,
        # it means layers(or Ops) are added into default_main_program switched by outer
        # `with` statement.
104 105
        self._main_program = framework.default_main_program()
        self._startup_program = framework.default_startup_program()
106
        self._func_cache = FunctionCache()
107
        self._feed_name_to_idx = {}
108 109 110 111
        # Stores the entry function of Net or Model.
        self._forward_func = None
        self._is_repeated = False
        # Indicates whether the function call is still building program.
112 113
        # Because user can call recursively when `Net` has sub class in
        # `forward()`.
114 115
        self._in_build_process = True

    def build_program_and_return_output(self, dyfunc, *args, **kwargs):
        """
        Builds the main_program with specialized inputs and returns the
        outputs of the program as fetch_list.
        """
        # Transforms the dygraph function into a static function and caches it.
        static_func = self._transform_or_cache_layers(dyfunc)

        # 1. Adds `fluid.data` layers for input if needed
        if not self._inputs:
            self._add_feed_layers(args, kwargs)

        # 2. Avoids inserting forward ops repeatedly.
        if self._is_repeated:
            return self.outputs

        # 3. Builds program only once and returns the output Variables.
        outputs = self._get_or_build_program(static_func, args, kwargs)

        if static_func == self._forward_func:
            self._in_build_process = False

        return outputs

    def _transform_or_cache_layers(self, dyfunc):
        """
        Transforms the dygraph function into a static function.
        """
        static_func = self._func_cache.get_or_cache_func(dyfunc)

        if self._forward_func is None:
            self._forward_func = static_func
        else:
            # self._forward_func is the entry function of the Net or Model.
            # It can be called multiple times, but layers from its call stack
            # are added into self._main_program only once. After that, the
            # cached program is always returned by default.
            if static_func == self._forward_func:
                self._is_repeated = True
            # If an independent function is received after the build process
            # has finished, the feed layers should be reset.
            # TODO(Aurelius84): Switch main_program without specifying program_guard.
            elif not self._in_build_process:
                self._inputs = []
                self._is_repeated = False
                self._forward_func = static_func

        return static_func

    def _get_or_build_program(self, func, args, kwargs):
        """
        Returns the program of the input function. When called for the first
        time, builds a new program and caches it.
        """
        with framework.program_guard(self._main_program, self._startup_program):
            if func == self._forward_func:
                # Replaces input data with `layers.data`
                args = list(args)
                for feed_layer in self._inputs:
                    idx = self.feed_name_to_idx[feed_layer.name]
                    args[idx] = feed_layer
                fetch_list = func(*args, **kwargs)
                self._outputs = fetch_list
            else:
                fetch_list = func(*args, **kwargs)

        return fetch_list

    def _add_feed_layers(self, args, kwargs):
        """
        Adds a `fluid.data` layer for each input that would be converted into
        a `Variable` by `to_variable()` in dygraph mode, so that the program
        can be fed and executed with `numpy.ndarray` data.
        """
        self._feed_name_to_idx = self._get_name_to_idx(self._forward_func)
        with framework.program_guard(self._main_program, self._startup_program):
            for feed_name, idx in self.feed_name_to_idx.items():
                batch_data = args[idx]
                assert isinstance(
                    batch_data, numpy.ndarray
                ), "Input {} should be numpy.ndarray, but received {}.".format(
                    feed_name, type(batch_data))
                feed_layer = data_layer_not_check(
                    name=feed_name,
                    shape=list(batch_data.shape),
                    dtype=str(batch_data.dtype))
                self._inputs.append(feed_layer)

    def _get_name_to_idx(self, func):
        """
        Returns the names and indices of the input args of `forward(args)`
        that need to be replaced with `fluid.data` layers.
        """
        transformer = self._func_cache.get_transformer(func)
        feed_name_to_idx = transformer.get_feed_name_to_idx()
        return feed_name_to_idx

    @property
    def main_program(self):
        return self._main_program

    @property
    def startup_program(self):
        return self._startup_program

    @property
    def inputs(self):
        return self._inputs

    @property
    def outputs(self):
        return self._outputs

    @property
    def feed_name_to_idx(self):
        return self._feed_name_to_idx

    @property
    def in_build_process(self):
        return self._in_build_process


class ProgramTranslator(object):
    """
    Singleton class that translates dygraph functions into static programs
    and runs them with an executor.
    """

    _singleton_lock = threading.Lock()
    _instance = None

    @synchronized
    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            # `object.__new__` takes no extra arguments in Python 3;
            # `args` and `kwargs` are consumed by `__init__`.
            cls._instance = object.__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            with cls._singleton_lock:
                cls._instance = cls()
        return cls._instance
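
    # A small sketch of the singleton behavior: constructing the class and
    # calling `get_instance()` both return the same object.
    #
    #     t1 = ProgramTranslator()
    #     t2 = ProgramTranslator.get_instance()
    #     assert t1 is t2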

    @classmethod
    def reset(cls):
        if cls._instance is not None:
            cls._instance._initialized = False
            cls._instance.__init__()

    def __init__(self, exe=None, place=None):
        # Makes sure that __init__ runs only once per instance.
        if self._initialized:
            return
        self._initialized = True
        self._place = core.CPUPlace() if place is None else place
        if exe is None:
            self._exe = executor.Executor(self._place)
        else:
            self._exe = exe
        self._program_cache = ProgramCache()
        self._optimizer_info = None
        self._optimizer = None
        self._loss_name = None
        # Once startup_program is changed, it should be run again.
        self._prev_startup = None

    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Returns the output tensors of the dygraph function given its arguments.
        """
        if in_dygraph_mode():
            warnings.warn(
                "The ProgramTranslator.get_output doesn't work in dygraph "
                "mode. We will just return dygraph output. Use it in "
                "static mode if you would like to translate to static graph.")
            return dygraph_func(*args, **kwargs)

        program_cache = self.get_program_cache()
        outputs = program_cache.build_program_and_return_output(dygraph_func,
                                                                *args, **kwargs)
        if not program_cache.in_build_process:
            outputs = self.run(*args, **kwargs)
            with guard():
                outputs = [to_variable(x) for x in outputs]
        return outputs
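
    # A hedged usage sketch of `get_output`; `net` is a hypothetical dygraph
    # function taking a numpy input (the names are illustrative only):
    #
    #     prog_trans = ProgramTranslator()
    #     x = numpy.random.random((4, 8)).astype('float32')
    #     outputs = prog_trans.get_output(net, x)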

    def get_func(self, dygraph_func):
        """
        Returns the translated static function of the dygraph function.
        """
        if in_dygraph_mode():
            warnings.warn(
                "The ProgramTranslator.get_func doesn't work in dygraph "
                "mode. We will just return dygraph function. Use it in "
                "static mode if you would like to translate to static graph.")
            return dygraph_func
        static_func = convert_function_with_cache(dygraph_func)
        return static_func
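
    # Sketch (same hypothetical `net` as above): the returned static function
    # can then be called in static mode to build the program:
    #
    #     static_net = prog_trans.get_func(net)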

    def get_program(self, dygraph_func, *args, **kwargs):
        """
        Returns the translated static program and the input/output variables
        of the dygraph function.
        """
        if in_dygraph_mode():
            warnings.warn(
                "The ProgramTranslator.get_program doesn't work in dygraph "
                "mode. We will just return dygraph output. Use it in static "
                "mode if you would like to translate to static graph.")
            return dygraph_func(*args, **kwargs)
        program_cache = self.get_program_cache()
        outputs = program_cache.build_program_and_return_output(dygraph_func,
                                                                *args, **kwargs)
        return self.main_program, self.startup_program, program_cache.inputs, outputs
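
    # Sketch of unpacking the result of `get_program` (same hypothetical
    # `net` and `x` as above):
    #
    #     main_prog, startup_prog, inputs, outputs = \
    #         prog_trans.get_program(net, x)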

    def get_code(self, dygraph_func):
        """
        Returns the translated static code of the dygraph function.
        """
        # Gets the AST of the dygraph function
        raw_code = inspect.getsource(dygraph_func)
        code = textwrap.dedent(raw_code)
        root = gast.parse(code)

        # Transforms the AST
        dygraph_to_static = DygraphToStaticAst()
        root_wrapper = dygraph_to_static.get_static_ast(root)

        # Gets the source code
        source_code = ast_to_source_code(root_wrapper.node)
        return source_code
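
    # A quick way to inspect the translation result (hypothetical `net`):
    #
    #     print(prog_trans.get_code(net))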

    def run(self, *args, **kwargs):
        """
        Executes main_program and returns output Tensors.
        """
        feed_dict, fetch_list = self._prepare(args)

        main_program = self._program_cache.main_program
        outputs = self._exe.run(main_program,
                                feed=feed_dict,
                                fetch_list=fetch_list)

        return outputs

    def set_optimizer(self, optimizer, index_of_loss=0):
        """
        Sets or updates the optimizer used to minimize the loss.
        """
        check_type(index_of_loss, "index_of_loss", int,
                   "ProgramTranslator.set_optimizer")
        self._check_cache_valid()
        if self._optimizer and self._loss_name:
            raise ValueError(
                "{} for {} has already been set before. Please make sure "
                "`set_optimizer` is not called in a loop.".format(
                    self._optimizer, self._loss_name))
        self._optimizer_info = (optimizer, index_of_loss)
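
    # Usage sketch (an illustrative SGD optimizer; the loss is assumed to be
    # the first output of the translated function):
    #
    #     sgd = fluid.optimizer.SGD(learning_rate=0.01)
    #     prog_trans.set_optimizer(sgd, index_of_loss=0)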

    def save_inference_model(self, dirname, feed=None, fetch=None):
        """
        Saves the current model as an inference model.
        """
        program_cache = self.get_program_cache()
        if feed is None:
            feeded_var_names = [i.name for i in program_cache.inputs]
        else:
            feeded_var_names = [program_cache.inputs[i].name for i in feed]

        target_vars = program_cache.outputs
        from paddle.fluid.io import save_inference_model
        save_inference_model(
            dirname=dirname,
            feeded_var_names=feeded_var_names,
            target_vars=target_vars,
            executor=self._exe,
            main_program=self.main_program.clone())
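
    # Sketch: exports the cached program for inference ("./infer_model" is an
    # illustrative path):
    #
    #     prog_trans.save_inference_model("./infer_model")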

    def _prepare(self, args):
        """
        Prepares the feed_dict and fetch_list, adds the optimizer if needed,
        and initializes variables by running startup_program.
        """

        # Updates batch_data for feed_dict
        feed_dict = self._update_batch_data(args)
        fetch_list = self._program_cache.outputs

        # Adds the optimizer if needed.
        if self._optimizer_info and self._optimizer is None:
            self._add_optimizer()

        if self._need_startup():
            self._exe.run(self.startup_program)
            self._prev_startup = self.startup_program

        return feed_dict, fetch_list

    def _need_startup(self):
        """
        Determines whether startup_program needs to be run.
        """
        if self.startup_program != self._prev_startup:
            check_type(self.startup_program, "startup_program",
                       framework.Program, "_need_startup")
            return len(self.startup_program.global_block().ops) > 0

        return False

    def _check_cache_valid(self):
        """
        Checks whether the current program is consistent with `default_main_program`.
        In some models and unit tests, the program is switched frequently by
        `program_guard`. If so, the cached program and other properties are no
        longer valid and should be reset.
        """
        if self._program_cache.main_program:
            if self._program_cache.main_program != framework.default_main_program(
            ):
                ProgramTranslator.reset()

    def _update_batch_data(self, args):
        """
        Updates the cached batch data while training the program.
        """
        feed_name_to_idx = self._program_cache.feed_name_to_idx
        feed_vars = self._program_cache.inputs
        feed_dict = {}
        for feed_var in feed_vars:
            idx = feed_name_to_idx[feed_var.name]
            feed_dict[feed_var.name] = args[idx]

        return feed_dict

    def _add_optimizer(self):
        """
        Appends the optimizer ops that minimize the loss to main_program.
        """
        optimizer, index_of_loss = self._optimizer_info

        outputs = self._program_cache.outputs
        outputs = [outputs] if not isinstance(outputs,
                                              (list, tuple)) else outputs

        assert abs(index_of_loss) < len(outputs), \
            "index_of_loss: {} shall not exceed the length of outputs: {}.".format(
            index_of_loss, len(outputs))

        loss_var = outputs[index_of_loss]
        check_type(loss_var, "loss_var", framework.Variable,
                   "ProgramTranslator._add_optimizer")

        main_program = self._program_cache.main_program
        startup_program = self._program_cache.startup_program
        all_vars = main_program.block(0).vars

        if all_vars.get(loss_var.name, None) is None:
            raise ValueError(
                "Can't find {} in main_program, please confirm whether the "
                "input loss is correct.".format(loss_var.name))
        # Adds the optimizer to minimize the loss
        with framework.program_guard(main_program, startup_program):
            optimizer.minimize(loss_var)

        self._optimizer = optimizer
        self._loss_name = loss_var.name

    def get_program_cache(self):
        """
        Returns the ProgramCache instance.
        """
        self._check_cache_valid()
        return self._program_cache

    @property
    def main_program(self):
        return self._program_cache.main_program

    @property
    def startup_program(self):
        return self._program_cache.startup_program