# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import numpy as np
import six

import paddle
from paddle.fluid import framework, backward, core, program_guard
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program, cast_model_to_fp16
from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
import paddle.compat as cpt
from paddle import _C_ops

class NestSequence(object):
    """
    A wrapper class that makes it easy to flatten and restore the nested
    structure of a given sequence.

    Args:
        raw_input: A (possibly nested) sequence of values.
        need_check(bool): If True, log a warning when the flattened values
            contain anything that is not a Variable/VarBase/eager Tensor.
            Default False.
    """

    def __init__(self, raw_input, need_check=False):
        self.__raw_input = raw_input
        self.__input_list = self.tolist()
        self.__var_ids = self._get_var_ids()
        self._check_non_variable(need_check)

    def tolist(self):
        """
        Flattens the nested sequences into single list.

        A new list is built on every call, so callers may mutate the result.
        """
        return flatten(self.__raw_input)

    def restore(self, value_list):
        """
        Restores the nested sequence from value list.

        `value_list` must have as many elements as the flattened raw input.
        """
        assert len(self.__input_list) == len(value_list)
        return pack_sequence_as(self.__raw_input, value_list)

    def _get_var_ids(self):
        """
        Return the indices (into the flattened input) of the tensor-like values.
        """
        var_ids = []
        for idx, var in enumerate(self.__input_list):
            if isinstance(var, (framework.Variable, core.VarBase,
                                core.eager.Tensor)):
                var_ids.append(idx)

        return var_ids

    def _check_non_variable(self, need_check):
        """
        Raises warning if output of traced function contains non-tensor type values.
        """
        if need_check:
            warning_types = set()
            for var in self.__input_list:
                if not isinstance(var, (framework.Variable, core.VarBase,
                                        core.eager.Tensor)):
                    warning_types.add(type(var))
            if warning_types:
                logging_utils.warn(
                    "Output of traced function contains non-tensor type values: {}. "
                    "Currently, We don't support to update them while training and will return "
                    "what we first saw. Please try to return them as tensor.".
                    format(list(warning_types)))

    @property
    def var_ids(self):
        """Indices of the tensor-like values in the flattened input."""
        return self.__var_ids

    def __getitem__(self, item):
        """Return the `item`-th value of the flattened input."""
        return self.__input_list[item]


class LazyInitialized(object):
    """
    Descriptor to implement lazy initialization of property.

    On first access through an instance, the wrapped function is called and
    its result is stored on the instance under the function's name. Because
    this descriptor defines no `__set__`, that instance attribute then takes
    precedence, so later accesses do not call the function again.
    """

    def __init__(self, function):
        self.function = function

    def __get__(self, instance, cls):
        # Accessed on the class itself (e.g. `Cls.attr`, `help(Cls)`,
        # `inspect`): hand back the descriptor instead of calling the
        # function with `None` as `self`.
        if instance is None:
            return self
        val = self.function(instance)
        setattr(instance, self.function.__name__, val)
        return val


def _change_is_test_status(program, is_test):
    # change all `is_test` attributes
    for block in program.blocks:
        for op in block.ops:
            if op.has_attr('is_test'):
                op._set_attr('is_test', is_test)
    return program


118
class PartialProgramLayer:
    """
    PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
    and executes them as a static subgraph.

    .. note::
        **1. This is a very low level API. Users should not use this API
             directly. Please use `partial_program_from(concrete_program)`
             to create it.
        **2. LoDTensorArray is not currently supported in the output.

    Args:
        main_program(Program): The main program that contains ops need to be executed.
        inputs(list[Variable]): The input list of the decorated function by `@declarative`.
        outputs(list[Variable]): The output list of the decorated function by `@declarative`.
        parameters(list[VarBase]|None): All trainable parameters included in the program. Default None.
        **kwargs: Only `build_strategy` (a BuildStrategy) is read; it
            defaults to a fresh `BuildStrategy()`.

    Returns:
        Layer: A Layer object that run all ops internally in static mode.
    """

    def __init__(self, main_program, inputs, outputs, parameters=None,
                 **kwargs):
        super(PartialProgramLayer, self).__init__()
        self._inputs = NestSequence(inputs)
        self._outputs = NestSequence(outputs, need_check=True)
        self._params = parameters if parameters is not None else []

        self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
        assert isinstance(self._build_strategy, BuildStrategy)

        # May prune `self._params`, so it runs before anything reads them.
        self._origin_main_program = self._verify_program(main_program)
        self._tmp_scope_vec = self._create_scope_vec()
        # A fake_var to handle empty input or output. It must exist before
        # `_get_double_grads`, which falls back to it through `_valid_vars`.
        self.__fake_vars = _create_fake_var()
        self._double_grads = self._get_double_grads(self._origin_main_program)
        # Set default mode to train
        self.training = True

        custom_white_list, custom_black_list = None, None
        tracer = framework._dygraph_tracer()
        if tracer:
            custom_white_list, custom_black_list = tracer._get_amp_op_list()
        # For AMP training
        self._amp_list = AutoMixedPrecisionLists(
            custom_white_list=custom_white_list,
            custom_black_list=custom_black_list)

    @LazyInitialized
    def _infer_program(self):
        """
        Lazy initialized property of infer_program.
        """
        return self._clone_for_test(self._origin_main_program)

    @LazyInitialized
    def _train_program(self):
        """
        Lazy initialized property of train_program.
        """
        train_program = self._append_backward_desc(self._origin_main_program)
        # Note: Only set grad type once after initializing train program. So we
        # put it here.
        self._set_grad_type(self._params, train_program)

        return train_program

    @LazyInitialized
    @switch_to_static_graph
    def _infer_amp_program(self):
        """
        Lazy initialized property of infer_amp_program.
        """
        infer_amp_program = self._origin_main_program.clone()
        with program_guard(infer_amp_program):
            rewrite_program(infer_amp_program, self._amp_list)

        return infer_amp_program

    @LazyInitialized
    def _train_amp_program(self):
        """
        Lazy initialized property of train_amp_program.
        """
        return self._append_backward_desc(self._infer_amp_program)

    @LazyInitialized
    @switch_to_static_graph
    def _infer_pure_fp16_program(self):
        """
        Lazy initialized property of _infer_pure_fp16_program.
        """
        infer_pure_fp16_program = self._origin_main_program.clone()
        with program_guard(infer_pure_fp16_program):
            cast_model_to_fp16(
                infer_pure_fp16_program, self._amp_list, use_fp16_guard=False)

        return infer_pure_fp16_program

    @LazyInitialized
    def _train_pure_fp16_program(self):
        """
        Lazy initialized property of _train_pure_fp16_program.
        """
        return self._append_backward_desc(self._infer_pure_fp16_program)

    @LazyInitialized
    def _infer_program_id(self):
        return _hash_with_id(self._infer_program, self)

    def _hash_train_program(self, train_program):
        """
        Hash `train_program` together with this layer and register
        `self._build_strategy` for the resulting id in the cached executor.
        """
        program_id = _hash_with_id(train_program, self)
        core._set_cached_executor_build_strategy(program_id,
                                                 self._build_strategy)
        return program_id

    @LazyInitialized
    def _train_program_id(self):
        return self._hash_train_program(self._train_program)

    @LazyInitialized
    def _train_amp_program_id(self):
        return self._hash_train_program(self._train_amp_program)

    @LazyInitialized
    def _train_pure_fp16_program_id(self):
        return self._hash_train_program(self._train_pure_fp16_program)

    def _verify_program(self, main_program):
        """
        Verify that all parameters of the program are initialized and prune
        the parameters that are not used by any op.

        `main_program` is returned unchanged. Op callstacks are NOT removed
        here; `_remove_op_call_stack` is not called by this method.
        """
        # 1. Check all params from main program can be found in self._params
        self._check_params_all_inited(main_program)
        # 2. Prune the parameters not used anywhere in the program.
        self._prune_unused_params(main_program)

        return main_program

    @switch_to_static_graph
    def _append_backward_desc(self, main_program):
        """
        Clone `main_program` in train mode and append gradient ops for every
        Variable output, if there are outputs and parameters.
        """
        # make sure all status of is_test are False in train mode.
        program = _change_is_test_status(main_program.clone(), is_test=False)
        targets = []
        for out in self._outputs.tolist():
            if isinstance(out, framework.Variable):
                targets.append(program.global_block().var(out.name))

        if targets and self._params:
            backward.gradients(targets=targets, inputs=[])

        return program

    def _prune_unused_params(self, program):
        """
        Prune the parameters not used anywhere in the program.
        The `@declarative` may only decorate a sub function which
        contains some unused parameters created in `__init__`.
        So prune these parameters to avoid unnecessary operations in
        `run_program_op`.
        """
        # Collect every name read or written by any op once, instead of
        # rescanning all ops for each parameter.
        used_names = set()
        for block in program.blocks:
            for op in block.ops:
                used_names.update(op.input_arg_names)
                used_names.update(op.output_arg_names)

        # Keeps the original parameter order.
        self._params = [
            param for param in self._params if param.name in used_names
        ]

    @staticmethod
    def _create_tensor_from_desc(var_desc):
        """
        Create a VarBase (or an eager Tensor in eager mode) from the dtype,
        shape, name and var type recorded in `var_desc`.
        """
        if not core._in_eager_mode():
            return core.VarBase(var_desc.dtype(),
                                var_desc.shape(),
                                var_desc.name(), var_desc.type(), False)
        return core.eager.Tensor(var_desc.dtype(),
                                 var_desc.shape(),
                                 var_desc.name(), var_desc.type(), False)

    def _get_double_grads(self, program):
        """
        Create a tensor for each variable whose name contains "@GRAD" in any
        block of `program`. Falls back to the fake vars when there are none.
        """
        double_grads = []
        for block in program.blocks:
            for name in block.vars:
                if "@GRAD" in name:
                    double_grads.append(
                        self._create_tensor_from_desc(block.vars[name].desc))
        return self._valid_vars(double_grads)

    def _get_end_op_index(self):
        """
        Return the number of ops in block 0 of the forward (infer) program
        matching the current AMP / pure-fp16 mode.
        """
        if _in_amp_guard():
            infer_program = self._infer_amp_program
        elif _in_pure_fp16_guard():
            infer_program = self._infer_pure_fp16_program
        else:
            infer_program = self._infer_program
        return infer_program.desc.block(0).op_size()

    def __call__(self, inputs):
        """
        Run the program on `inputs` and return the outputs restored to their
        original nested structure, with "no value" placeholders removed.
        """
        in_vars, out_vars = self._prepare(inputs)

        # `end_op_index` is the op count of the matching forward program.
        attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
                 0, 'end_op_index', self._get_end_op_index(), 'is_test',
                 not self.training, 'program_id', self.program_id)

        self._cast_fp16_if_pure_fp16(in_vars)

        _C_ops.run_program(
            self._valid_vars(in_vars),
            self._valid_vars(self._params),
            self._valid_vars(out_vars), self._tmp_scope_vec, self._double_grads,
            *attrs)
        self.drop_scope_if_no_grad()
        restored_nest_out = self._restore_out(out_vars)
        return self._remove_no_value(restored_nest_out)

    def _cast_fp16_if_pure_fp16(self, in_vars):
        """
        Under pure-fp16 guard, replace (in place) each input whose program
        variable is float16 with a float16 cast that keeps the same name.
        """
        if not _in_pure_fp16_guard():
            return
        # Loop-invariant: the active program does not change during the loop.
        global_block = self.program.global_block()
        for i, var in enumerate(in_vars):
            name = var.name
            if (global_block.has_var(name) and
                    global_block.var(name).dtype == paddle.float16):
                in_vars[i] = var.astype('float16')
                in_vars[i].name = name

    def drop_scope_if_no_grad(self):
        """
        In train mode without gradient recording, drop the child scopes kept
        in the scope vec. No-op when there is no scope vec (eager mode).
        """
        tracer = framework._dygraph_tracer()
        if (self.training and not tracer._has_grad and
                self._tmp_scope_vec is not None):
            self._tmp_scope_vec.value().get_scope().drop_kids()

    @property
    def program(self):
        """The program to run for the current train/AMP/pure-fp16 mode."""
        if self.training:
            if _in_amp_guard():
                return self._train_amp_program
            elif _in_pure_fp16_guard():
                return self._train_pure_fp16_program
            else:
                return self._train_program
        else:
            return self._infer_program

    @property
    def program_id(self):
        """The id of `self.program` for the current mode."""
        if self.training:
            if _in_amp_guard():
                return self._train_amp_program_id
            elif _in_pure_fp16_guard():
                return self._train_pure_fp16_program_id
            else:
                return self._train_program_id
        else:
            return self._infer_program_id

    def _prepare(self, inputs):
        """
        Prepare inputs, outputs, attrs.

        Returns:
            tuple: (input tensors renamed to match the program inputs,
            freshly created output tensors to receive results).
        """
        assert isinstance(inputs, (tuple, list))
        # Flatten inputs with nested structure into single list.
        flatten_inputs = flatten(inputs)
        # Convert variable into VarBase and feed in training data.
        input_vars = []
        expected_place = framework._current_expected_place()
        for i, value in enumerate(flatten_inputs):
            if isinstance(value, np.ndarray):
                if not core._in_eager_mode():
                    var = core.VarBase(
                        value=value,
                        name=self._inputs[i].desc.name(),
                        persistable=False,
                        place=expected_place,
                        zero_copy=True)
                else:
                    var = core.eager.Tensor(
                        value=value,
                        name=self._inputs[i].desc.name(),
                        persistable=False,
                        place=expected_place,
                        zero_copy=True)
            elif isinstance(value, (core.VarBase, core.eager.Tensor)):
                # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
                # into CUDAPlace when it's as input of multi Ops. so we move it in advance
                # to avoid this problem.
                if value.stop_gradient and not value.place._equals(
                        expected_place):
                    var = value._copy_to(expected_place, False)
                    var.stop_gradient = True
                else:
                    var = value
                var.name = self._inputs[i].desc.name()
            else:
                continue
            input_vars.append(var)

        def create_out(var_id):
            var = self._outputs[var_id]
            assert isinstance(var, framework.Variable)
            return self._create_tensor_from_desc(var.desc)

        # Create VarBase to receive output data.
        out_vars = list(map(create_out, self._outputs.var_ids))

        return input_vars, out_vars

    def _create_scope_vec(self):
        """
        Create the STEP_SCOPES variable handed to `run_program`, wrapping a
        fresh `core.Scope()`. Returns None in eager mode (not supported yet).
        """
        # Hold forward variables
        tmp_scope_vec = None
        if not core._in_eager_mode():
            tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
                                         "program_out_scope",
                                         core.VarDesc.VarType.STEP_SCOPES, True)
            # TODO(jiabin): Support eager mode later with
            # core.eager.Tensor(core.VarDesc.VarType.FP32, [],
            #                   "program_out_scope",
            #                   core.VarDesc.VarType.STEP_SCOPES, True)

            inner_scope = core.Scope()
            tmp_scope_vec.value().set_scope(inner_scope)
        return tmp_scope_vec

    def _restore_out(self, out_vars):
        """
        Restores same nested outputs by only replacing the Variable with VarBase.

        A single-element result is unwrapped to the element itself.
        """
        flatten_outputs = self._outputs.tolist()
        for i, idx in enumerate(self._outputs.var_ids):
            flatten_outputs[idx] = out_vars[i]
        outs = self._outputs.restore(flatten_outputs)
        if outs is not None and len(outs) == 1:
            outs = outs[0]

        return outs

    @switch_to_static_graph
    def _clone_for_test(self, main_program):
        return main_program.clone(for_test=True)

    def _is_no_value(self, var):
        """
        True if `var` is a shape-[1] tensor holding RETURN_NO_VALUE_MAGIC_NUM.
        """
        if isinstance(var,
                      (core.VarBase, core.eager.Tensor)) and var.shape == [1]:
            # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
            if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
                return True
        return False

    def _remove_no_value(self, out_vars):
        """
        Removes invalid value for various-length return statement
        """
        if isinstance(out_vars, (core.VarBase, core.eager.Tensor)):
            if self._is_no_value(out_vars):
                return None
            return out_vars
        elif isinstance(out_vars, (tuple, list)):
            if isinstance(out_vars, tuple):
                res = tuple(
                    var for var in out_vars if not self._is_no_value(var))
            else:
                # isinstance(out_vars, list)
                res = [var for var in out_vars if not self._is_no_value(var)]

            has_removed = (len(out_vars) > len(res))
            # len(out_vars) > len(res) means we have removed var. This is
            # preventing out_vars is empty or just one element at the beginning
            if len(res) == 0 and has_removed:
                return None
            elif len(res) == 1 and has_removed:
                return res[0]
            return res

        return out_vars

    def _set_grad_type(self, params, train_program):
        # NOTE: if user set sparse gradient mode, the param's gradient
        # will be SelectedRows, not LoDTensor. But tracer will just
        # set param grad VarBase by forward VarBase(LoDTensor)
        # If we don't change grad_var type here, RunProgramOp need
        # transform SelectedRows to LoDTensor forcibly, it may not
        # be user wanted result.
        for param in params:
            grad_name = param.name + core.grad_var_suffix()
            grad_var = train_program.desc.block(0).find_var(
                cpt.to_bytes(grad_name))
            # NOTE: cannot find var desc maybe no problem, such as in batch_norm
            if grad_var is None:
                continue
            param._set_grad_type(grad_var.type())

    def _remove_op_call_stack(self, main_program):
        """
        Remove op's python call stack with redundant low-level error messages related to
        transformations to avoid confusing users.
        """
        assert isinstance(main_program, framework.Program)
        for block in main_program.blocks:
            for op in block.ops:
                if op.has_attr("op_callstack"):
                    op._remove_attr("op_callstack")

        return main_program

    def _check_params_all_inited(self, main_program):
        """
        Check all params from main program are already initialized, see details as follows:
            1. all parameters in self._params should be type `framework.ParamBase` which are created in dygraph.
            2. all parameters from transformed program can be found in self._params.
               Because they share same data with ParamBase of original dygraph.
        """
        if not isinstance(self._params, (list, tuple)):
            raise TypeError(
                "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
                % type(self._params))

        param_and_buffer_names_set = set()
        for i, var in enumerate(self._params):
            # self._params contains parameters and buffers with persistable=True.
            if not isinstance(var, (core.VarBase, core.eager.Tensor)):
                raise TypeError(
                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'.
                    format(i, type(var)))
            param_and_buffer_names_set.add(var.name)

        for block in main_program.blocks:
            for name, var in six.iteritems(block.vars):
                if isinstance(var, framework.Parameter):
                    if name not in param_and_buffer_names_set:
                        raise ValueError(
                            "\n\tWe don't support to define layer with parameters in the function decorated by `@to_static`."
                            "\n\tBut we found parameter(%s) was created in the decorated function."
                            "\n"
                            "\n\tRevise suggestion: "
                            "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
                            "\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
                            % name)

    def _valid_vars(self, vars):
        """
        Note: run_program_op.InferShape requires `X`/'Out' not be null.
        But it's common in dy2static, fake varBase is created to handle the
        problem.
        """
        return vars if vars else self.__fake_vars


def _create_fake_var():
    """
    Create a fake_var (force on CPU) to handle empty input or output.

    Returns:
        list: In legacy dygraph mode, a one-element list holding a RAW-typed
        VarBase named "Fake_var". In eager mode, an empty list, because the
        eager fake var is not supported yet.
    """
    if core._in_eager_mode():
        # TODO(jiabin): Support this later by returning
        # [core.eager.Tensor(core.VarDesc.VarType.FP32, [], "Fake_var",
        #                    core.VarDesc.VarType.RAW, False)]
        return []
    return [
        core.VarBase(core.VarDesc.VarType.FP32, [], "Fake_var",
                     core.VarDesc.VarType.RAW, False)
    ]


def partial_program_from(concrete_program):
    """
    Build a PartialProgramLayer from `concrete_program`.

    If the first input is a `layers.Layer` (presumably the `self` of a
    decorated Layer method), it is dropped. Only the remaining inputs are
    passed on. `concrete_program.kwargs` (e.g. `build_strategy`) are
    forwarded unchanged.
    """
    inputs = concrete_program.inputs
    if inputs and isinstance(inputs[0], layers.Layer):
        inputs = inputs[1:]

    return PartialProgramLayer(
        concrete_program.main_program, inputs, concrete_program.outputs,
        concrete_program.parameters, **concrete_program.kwargs)