program_translator.py 16.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import functools
import inspect
import textwrap
import threading
import warnings
from collections import defaultdict

import gast
import numpy
import six

from paddle.fluid import framework
from paddle.fluid import core, executor
from paddle.fluid.dygraph import guard, to_variable
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import data_layer_not_check
from paddle.fluid.framework import in_dygraph_mode
from paddle.fluid.data_feeder import check_type

# Public API exported by `from ... import *`.
__all__ = [
    'ProgramTranslator',
    'convert_function_with_cache',
]


class FunctionCache(object):
    """
    Memoizes dygraph-to-static conversions, keyed by dedented source code,
    so a function whose source was already converted is not converted again.

    Two mappings are maintained:
      * dedented source code -> static function produced by
        `convert_to_static`;
      * static function -> the transformer returned alongside it.
    """

    def __init__(self):
        # Dedented source text -> converted static function.
        self._dycode_to_static_func = {}
        # Converted static function -> transformer that produced it.
        self._static_func_to_transformer = {}

    def get_or_cache_func(self, func):
        """
        Return the static version of `func`.

        On the first request for a given source text, `func` is converted
        with `convert_to_static` and both the result and its transformer are
        cached; later requests return the cached static function.
        """
        source_key = self._get_dedent_code_string(func)
        cached = self._dycode_to_static_func.get(source_key)
        if cached is not None:
            return cached

        converted, transformer = convert_to_static(func)
        self._dycode_to_static_func[source_key] = converted
        self._static_func_to_transformer[converted] = transformer
        return converted

    def get_transformer(self, func):
        """Return the transformer recorded for static function `func`, or None."""
        return self._static_func_to_transformer.get(func)

    def _get_dedent_code_string(self, func):
        """Return the source of `func` with its common leading indentation removed."""
        return textwrap.dedent(inspect.getsource(func))

    def exist(self, func):
        """Return True if the source of `func` already has a cached static function."""
        source_key = self._get_dedent_code_string(func)
        return self._dycode_to_static_func.get(source_key) is not None


# Lock held by `convert_function_with_cache` around every lookup/conversion
# in `_FUNCTION_CACHE`, so concurrent callers access the cache one at a time.
_CACHE_LOCK = threading.Lock()
# Module-level cache shared by all callers of `convert_function_with_cache`.
_FUNCTION_CACHE = FunctionCache()


def convert_function_with_cache(dygraph_func):
    """
    Return the static-graph version of `dygraph_func`.

    The result is looked up in (and, on a miss, stored into) the shared
    module-level `_FUNCTION_CACHE`. Access is serialized with `_CACHE_LOCK`.
    """
    with _CACHE_LOCK:
        return _FUNCTION_CACHE.get_or_cache_func(dygraph_func)


def synchronized(func):
    """
    Decorator that serializes all calls to `func` through one lock.

    A `threading.Lock` is created once per decorated function and stored on
    it as `__lock__`. Every call of the returned wrapper holds that lock for
    the whole call. The lock is not reentrant, so a decorated function must
    not (directly or indirectly) call itself.

    `functools.wraps` copies the wrapped function's metadata (`__name__`,
    `__qualname__`, `__doc__`, `__module__`, `__dict__`) onto the wrapper.
    The decorated function therefore keeps its identity for introspection
    and error messages, and `__lock__` is reachable from the wrapper as well.
    """
    func.__lock__ = threading.Lock()

    @functools.wraps(func)
    def lock_func(*args, **kwargs):
        with func.__lock__:
            return func(*args, **kwargs)

    return lock_func


class ProgramCache(object):
    """
    Holds the static program built from a dygraph entry function, together
    with its feed (input) layers and fetch (output) variables.

    The first function converted through this cache is treated as the entry
    function (the `forward` of a Net/Model). It may be called repeatedly,
    but its ops are appended to the main program only once; afterwards the
    cached outputs are returned.
    """

    def __init__(self):
        # Always bind to the current default programs: once `__call__` runs,
        # layers (ops) are added into the default_main_program selected by
        # an outer `with` (program_guard) statement.
        self._main_program = framework.default_main_program()
        self._startup_program = framework.default_startup_program()
        # Conversion cache private to this program cache.
        self._func_cache = FunctionCache()
        # `fluid.data` feed layers and the outputs of the entry function.
        self._inputs = []
        self._outputs = []
        # Static entry function (forward of Net or Model).
        self._forward_func = None
        # Feed variable name -> positional index of the matching argument.
        self._feed_name_to_idx = {}
        # Set once the entry function is requested again after the first time.
        self._is_repeated = False
        # True while the program is still being built. Calls can be
        # recursive (a sub-layer's forward inside the entry forward), so
        # this only becomes False when the entry function itself returns.
        self._in_build_process = True

    def build_program_and_return_output(self, dyfunc, *args, **kwargs):
        """
        Build the main program by running the static version of `dyfunc`
        and return its output variables, which serve as the fetch_list.
        """
        static_func = self._transform_or_cache_layers(dyfunc)

        # Feed layers are created only when none exist yet.
        if not self._inputs:
            self._add_feed_layers(args, kwargs)

        # The entry function's ops are already in the program; don't re-add them.
        if self._is_repeated:
            return self.outputs

        result = self._get_or_build_program(static_func, args, kwargs)
        # Returning from the entry function means building has finished.
        if static_func == self._forward_func:
            self._in_build_process = False
        return result

    def _transform_or_cache_layers(self, dyfunc):
        """
        Return the static function for `dyfunc`.

        The first function ever seen becomes the entry function. Seeing the
        entry function again marks the cache as repeated, so later calls
        return the already-built outputs instead of adding ops again.
        """
        converted = self._func_cache.get_or_cache_func(dyfunc)
        if converted == self._forward_func:
            self._is_repeated = True
        if self._forward_func is None:
            self._forward_func = converted
        return converted

    def _get_or_build_program(self, func, args, kwargs):
        """
        Call `func` under this cache's main/startup programs and return its
        outputs.

        For the entry function, the positional arguments at the feed indices
        are replaced with the feed layers, and the outputs are cached.
        """
        with framework.program_guard(self._main_program, self._startup_program):
            if func == self._forward_func:
                call_args = list(args)
                for layer in self._inputs:
                    call_args[self.feed_name_to_idx[layer.name]] = layer
                self._outputs = func(*call_args, **kwargs)
                return self._outputs
            return func(*args, **kwargs)

    def _add_feed_layers(self, args, kwargs):
        """
        Create one `fluid.data` feed layer (via `data_layer_not_check`) per
        input of the entry function that must be fed. Shape and dtype are
        taken from the corresponding numpy.ndarray positional argument, and
        each layer is appended to `self._inputs`.
        """
        if not self._feed_name_to_idx:
            self._feed_name_to_idx = self._get_name_to_idx(self._forward_func)
        with framework.program_guard(self._main_program, self._startup_program):
            for name, position in self.feed_name_to_idx.items():
                value = args[position]
                assert isinstance(
                    value, numpy.ndarray
                ), "Input {} should be numpy.ndarray, but received {}.".format(
                    name, type(value))
                self._inputs.append(
                    data_layer_not_check(
                        name=name,
                        shape=list(value.shape),
                        dtype=str(value.dtype)))

    def _get_name_to_idx(self, func):
        """
        Return the {feed name: positional index} mapping reported by the
        transformer that produced static function `func`.
        """
        return self._func_cache.get_transformer(func).get_feed_name_to_idx()

    @property
    def main_program(self):
        """Main program that ops are added to."""
        return self._main_program

    @property
    def startup_program(self):
        """Startup program paired with `main_program`."""
        return self._startup_program

    @property
    def inputs(self):
        """Feed layers created for the entry function."""
        return self._inputs

    @property
    def outputs(self):
        """Outputs returned by the entry function when it was built."""
        return self._outputs

    @property
    def feed_name_to_idx(self):
        """Mapping from feed layer name to positional argument index."""
        return self._feed_name_to_idx

    @property
    def in_build_process(self):
        """True until the entry function has finished building the program."""
        return self._in_build_process


class ProgramTranslator(object):
    """
    Singleton that translates dygraph functions into static programs, runs
    them with an executor and, optionally, attaches an optimizer.

    Every construction (`ProgramTranslator(...)` or `get_instance()`)
    returns the same object. `__init__` performs its setup only the first
    time, so constructor arguments passed later are ignored. `reset()`
    re-runs the setup on the existing instance with default arguments.
    """

    _singleton_lock = threading.Lock()
    _instance = None

    @synchronized
    def __new__(cls, *args, **kwargs):
        # Constructor arguments (exe, place) are consumed by `__init__` only.
        # They must not be forwarded to `object.__new__`: since this class
        # overrides `__new__`, `object.__new__(cls, <extra args>)` raises
        # TypeError in Python 3.
        if cls._instance is None:
            cls._instance = object.__new__(cls)
            # Tells `__init__` that the instance still needs to be set up.
            cls._instance._initialized = False
        return cls._instance

    @classmethod
    def get_instance(cls):
        """
        Return the singleton instance, creating it with default arguments
        if it does not exist yet.
        """
        if cls._instance is None:
            with cls._singleton_lock:
                # Re-check under the lock: another thread may have created
                # the instance after the unlocked check above.
                if cls._instance is None:
                    cls._instance = cls()
        return cls._instance

    @classmethod
    def reset(cls):
        """
        Re-initialize the existing instance (if any) with default arguments.

        The instance gets a new executor on `core.CPUPlace()`, a fresh
        ProgramCache and no optimizer.
        """
        if cls._instance is not None:
            cls._instance._initialized = False
            cls._instance.__init__()

    def __init__(self, exe=None, place=None):
        """
        Args:
            exe: executor used to run programs. Defaults to
                `executor.Executor(place)`.
            place: device place. Defaults to `core.CPUPlace()`.
        """
        # `__new__` always returns the same object, so Python calls
        # `__init__` on every construction; only the first call sets up state.
        if self._initialized:
            return
        self._initialized = True
        self._place = core.CPUPlace() if place is None else place
        if exe is None:
            self._exe = executor.Executor(self._place)
        else:
            self._exe = exe
        self._program_cache = ProgramCache()
        # (optimizer, index_of_loss) recorded by `set_optimizer`.
        self._optimizer_info = None
        # Optimizer actually added to the program by `_add_optimizer`.
        self._optimizer = None
        # Name of the loss variable the optimizer minimizes.
        self._loss_name = None
        # Once main_program is changed, should run startup_program.
        self._need_startup = True

    def get_output(self, dygraph_func, *args, **kwargs):
        """
        Return the output tensors of `dygraph_func` for the given arguments.

        In dygraph mode the function is simply called (with a warning). In
        static mode the program is built first. While building is still in
        progress (e.g. nested calls), the symbolic outputs are returned.
        Once building has finished, the program is executed and the fetched
        results are returned as dygraph Variables.
        """
        if in_dygraph_mode():
            warnings.warn(
                "The ProgramTranslator.get_output doesn't work in dygraph "
                "mode. We will just return dygraph output. Use it in "
                "static mode if you would like to translate to static graph.")
            return dygraph_func(*args, **kwargs)

        program_cache = self.get_program_cache()
        outputs = program_cache.build_program_and_return_output(dygraph_func,
                                                                *args, **kwargs)
        if not program_cache.in_build_process:
            outputs = self.run(*args, **kwargs)
            with guard():
                outputs = [to_variable(x) for x in outputs]
        return outputs

    def get_func(self, dygraph_func):
        """
        Return the translated static function for `dygraph_func`.

        In dygraph mode `dygraph_func` itself is returned (with a warning).
        """
        if in_dygraph_mode():
            warnings.warn(
                "The ProgramTranslator.get_func doesn't work in dygraph "
                "mode. We will just return dygraph function. Use it in "
                "static mode if you would like to translate to static graph.")
            return dygraph_func
        return convert_function_with_cache(dygraph_func)

    def get_program(self, dygraph_func, *args, **kwargs):
        """
        Return the translated static program and its input/output variables.

        Returns:
            A tuple (main_program, startup_program, inputs, outputs).
            In dygraph mode the plain output of `dygraph_func` is returned
            instead (with a warning).
        """
        if in_dygraph_mode():
            warnings.warn(
                "The ProgramTranslator.get_program doesn't work in dygraph "
                "mode. We will just return dygraph output. Use it in static "
                "mode if you would like to translate to static graph.")
            return dygraph_func(*args, **kwargs)
        program_cache = self.get_program_cache()
        outputs = program_cache.build_program_and_return_output(dygraph_func,
                                                                *args, **kwargs)
        return self.main_program, self.startup_program, program_cache.inputs, outputs

    def get_code(self, dygraph_func):
        """
        Return the source code of the static function translated from
        `dygraph_func`.
        """
        # Parse the (dedented) dygraph source into an AST.
        raw_code = inspect.getsource(dygraph_func)
        code = textwrap.dedent(raw_code)
        root = gast.parse(code)

        # Transform the AST into its static-graph form.
        dygraph_to_static = DygraphToStaticAst()
        root_wrapper = dygraph_to_static.get_static_ast(root)

        # Render the transformed AST back to source code.
        return ast_to_source_code(root_wrapper.node)

    def run(self, *args, **kwargs):
        """
        Execute the cached main_program and return the fetched outputs.

        Only positional `args` are fed (by the indices recorded in the
        program cache). `kwargs` are accepted but not fed.
        """
        feed_dict, fetch_list = self._prepare(args)

        main_program = self._program_cache.main_program
        return self._exe.run(main_program,
                             feed=feed_dict,
                             fetch_list=fetch_list)

    def set_optimizer(self, optimizer, index_of_loss=0):
        """
        Register `optimizer` to minimize the output at `index_of_loss`.

        The optimizer is added to the program lazily, on the next `run`.

        Raises:
            ValueError: if an optimizer has already been added to the program.
        """
        check_type(index_of_loss, "index_of_loss", int,
                   "ProgramTranslator.set_optimizer")
        self._check_cache_valid()
        if self._optimizer and self._loss_name:
            raise ValueError(
                "{} for {} has already been set before. Please confirm not to call `set_optimizer` in for loop. ".
                format(self._optimizer, self._loss_name))
        self._optimizer_info = (optimizer, index_of_loss)

    def save_inference_model(self, dirname, feed=None, fetch=None):
        """
        Save the current model as an inference model.

        Args:
            dirname: directory to save the model into.
            feed: optional indices into the cached input variables to use as
                feeds. Defaults to all inputs.
            fetch: optional indices into the cached output variables to use
                as fetch targets. Defaults to all outputs.
        """
        program_cache = self.get_program_cache()
        if feed is None:
            feeded_var_names = [i.name for i in program_cache.inputs]
        else:
            feeded_var_names = [program_cache.inputs[i].name for i in feed]

        outputs = program_cache.outputs
        if fetch is None:
            target_vars = outputs
        else:
            # Outputs may be a single variable; index it like a 1-element list.
            if not isinstance(outputs, (list, tuple)):
                outputs = [outputs]
            target_vars = [outputs[i] for i in fetch]

        from paddle.fluid.io import save_inference_model
        save_inference_model(
            dirname=dirname,
            feeded_var_names=feeded_var_names,
            target_vars=target_vars,
            executor=self._exe,
            main_program=self.main_program.clone())

    def _prepare(self, args):
        """
        Build the feed dict and fetch list for a run.

        Also adds a registered-but-not-yet-added optimizer, and runs
        startup_program if it has not been run yet.
        """
        feed_dict = self._update_batch_data(args)
        fetch_list = self._program_cache.outputs

        if self._optimizer_info and self._optimizer is None:
            self._add_optimizer()

        if self._need_startup:
            self._exe.run(self.startup_program)
            self._need_startup = False

        return feed_dict, fetch_list

    def _check_cache_valid(self):
        """
        Reset the translator if the cached program is no longer the
        `default_main_program`.

        Programs may be switched frequently via `program_guard` (e.g. in
        some models and unit tests). When that happens, the cached program
        and related state are stale.
        """
        cached_program = self._program_cache.main_program
        if cached_program and cached_program != framework.default_main_program():
            ProgramTranslator.reset()

    def _update_batch_data(self, args):
        """
        Return a feed dict mapping each cached input's name to the
        positional argument at its recorded index.
        """
        feed_name_to_idx = self._program_cache.feed_name_to_idx
        return {
            feed_var.name: args[feed_name_to_idx[feed_var.name]]
            for feed_var in self._program_cache.inputs
        }

    def _add_optimizer(self):
        """
        Add the optimizer registered by `set_optimizer` to the cached
        programs, minimizing the output at the registered `index_of_loss`.

        Raises:
            AssertionError: if `index_of_loss` is out of range for outputs.
            ValueError: if the loss variable is not in main_program.
        """
        optimizer, index_of_loss = self._optimizer_info

        outputs = self._program_cache.outputs
        outputs = [outputs] if not isinstance(outputs,
                                              (list, tuple)) else outputs

        # Python-style indexing: valid indices are -len(outputs) .. len(outputs) - 1.
        assert -len(outputs) <= index_of_loss < len(outputs), \
            "index_of_loss: {} shall not exceed the length of outputs: {}.".format(
                index_of_loss, len(outputs))

        loss_var = outputs[index_of_loss]
        check_type(loss_var, "loss_var", framework.Variable,
                   "ProgramTranslator._add_optimizer")

        main_program = self._program_cache.main_program
        startup_program = self._program_cache.startup_program
        all_vars = main_program.block(0).vars

        if all_vars.get(loss_var.name, None) is None:
            raise ValueError(
                "Can't find {} in main_program, please confirm whether the input loss is correct."
                .format(loss_var.name))
        # Add optimizer ops to minimize the loss.
        with framework.program_guard(main_program, startup_program):
            optimizer.minimize(loss_var)

        self._optimizer = optimizer
        self._loss_name = loss_var.name

    def get_program_cache(self):
        """
        Return the ProgramCache, resetting the translator first if the
        cached program is stale.
        """
        self._check_cache_valid()
        return self._program_cache

    @property
    def main_program(self):
        return self._program_cache.main_program

    @property
    def startup_program(self):
        return self._program_cache.startup_program