# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import numpy as np
import six

import paddle
from paddle.fluid import framework, backward, core, program_guard
from paddle.fluid.dygraph import layers
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static import logging_utils
from paddle.fluid.dygraph.dygraph_to_static.return_transformer import RETURN_NO_VALUE_MAGIC_NUM
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.layers.utils import _hash_with_id
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.contrib.mixed_precision.decorator import AutoMixedPrecisionLists
from paddle.fluid.contrib.mixed_precision.fp16_utils import rewrite_program, cast_model_to_fp16
from paddle.fluid.dygraph.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
import paddle.compat as cpt
from paddle import _C_ops


class NestSequence(object):
    """
    Wrapper that flattens a (possibly nested) sequence and can later rebuild
    the same nest structure from a flat list of values.

    Args:
        raw_input: The nested structure to wrap.
        need_check(bool): If True, log a warning when the flattened structure
            contains values that are not tensors. Default False.
    """

    def __init__(self, raw_input, need_check=False):
        self.__raw_input = raw_input
        self.__input_list = self.tolist()
        self.__var_ids = self._get_var_ids()
        self._check_non_variable(need_check)

    def tolist(self):
        """
        Return the wrapped nested structure flattened into one list.
        """
        return flatten(self.__raw_input)

    def restore(self, value_list):
        """
        Rebuild the wrapped nest structure, filling it with `value_list`.

        `value_list` must have exactly as many entries as the flattened input.
        """
        assert len(self.__input_list) == len(value_list)
        return pack_sequence_as(self.__raw_input, value_list)

    def _get_var_ids(self):
        # Positions, within the flattened list, of the entries that are tensors.
        return [
            position for position, item in enumerate(self.__input_list)
            if isinstance(item, (framework.Variable, core.VarBase,
                                 core.eager.Tensor))
        ]

    def _check_non_variable(self, need_check):
        """
        Warn if the output of the traced function contains non-tensor values.
        """
        if not need_check:
            return
        warning_types = {
            type(item)
            for item in self.__input_list
            if not isinstance(item, (framework.Variable, core.VarBase,
                                     core.eager.Tensor))
        }
        if not warning_types:
            return
        logging_utils.warn(
            "Output of traced function contains non-tensor type values: {}. "
            "Currently, We don't support to update them while training and will return "
            "what we first saw. Please try to return them as tensor.".
            format(list(warning_types)))

    @property
    def var_ids(self):
        # Indices of tensor entries in the flattened input.
        return self.__var_ids

    def __getitem__(self, item):
        return self.__input_list[item]


class LazyInitialized(object):
    """
    Descriptor to implement lazy initialization of property.

    On first access through an instance, the wrapped function is called with
    that instance. Its result is stored as an instance attribute under the
    same name the descriptor is bound to in the owning class. Because this is
    a non-data descriptor (it defines no `__set__`), the instance attribute
    then shadows the descriptor, so the function runs at most once per
    instance.
    """

    def __init__(self, function):
        self.function = function
        # Fallback name, used if `__set_name__` is never called (e.g. the
        # descriptor is attached to a class after the class is created).
        self._attr_name = function.__name__

    def __set_name__(self, owner, name):
        # `name` is the key under which this descriptor lives in the class
        # namespace. For private members such as `__fake_vars` it is the
        # mangled `_Owner__fake_vars`, which differs from `function.__name__`.
        # Caching under the unmangled name would never shadow the descriptor,
        # and the function would re-run on every access.
        self._attr_name = name

    def __get__(self, instance, cls):
        # Accessed on the class itself: hand back the descriptor instead of
        # calling the function with `None`.
        if instance is None:
            return self
        val = self.function(instance)
        setattr(instance, self._attr_name, val)
        return val


def _change_is_test_status(program, is_test):
    """
    Set the `is_test` attribute on every op of `program` that defines it.

    Args:
        program(Program): The program to modify in place.
        is_test(bool): The value assigned to each op's `is_test` attribute.

    Returns:
        Program: The same `program` object, modified in place.
    """
    flagged_ops = (op for blk in program.blocks for op in blk.ops
                   if op.has_attr('is_test'))
    for op in flagged_ops:
        op._set_attr('is_test', is_test)
    return program


class PartialProgramLayer:
    """
    PartialProgramLayer wraps all the ops from layers decorated by `@declarative`
    and executes them as a static subgraph.

    .. note::
        1. This is a very low level API. Users should not use this API
           directly. Please use `partial_program_from(concrete_program)`
           to create it.
        2. LoDTensorArray is not currently supported in the output.

    Args:
        main_program(Program): The main program that contains ops need to be executed.
        inputs(list[Variable]): The input list of the decorated function by `@declarative`.
        outputs(list[Variable]): The output list of the decorated function by `@declarative`.
        parameters(list[VarBase]|None): All trainable parameters included in the program. Default None.
        **kwargs: Extra options. Only `build_strategy` (a `BuildStrategy`) is
            read here; it defaults to a new `BuildStrategy()`.

    Returns:
        PartialProgramLayer: An object that runs all ops internally in static mode.
    """

    def __init__(self,
                 main_program,
                 inputs,
                 outputs,
                 parameters=None,
                 **kwargs):
        super(PartialProgramLayer, self).__init__()
        self._inputs = NestSequence(inputs)
        self._outputs = NestSequence(outputs, need_check=True)
        self._params = parameters if parameters is not None else []

        self._build_strategy = kwargs.get('build_strategy', BuildStrategy())
        assert isinstance(self._build_strategy, BuildStrategy)

        self._origin_main_program = self._verify_program(main_program)
        self._tmp_scope_vec = self._create_scope_vec()
        self._cuda_graph_vec = self._create_cuda_graph_vec()
        self._cuda_graph_capture_mode = ""
        self._cuda_graph_pool_id = 0
        # Set default mode to train
        self.training = True

        custom_white_list, custom_black_list = None, None
        tracer = framework._dygraph_tracer()
        if tracer:
            custom_white_list, custom_black_list = tracer._get_amp_op_list()
        # For AMP training
        self._amp_list = AutoMixedPrecisionLists(
            custom_white_list=custom_white_list,
            custom_black_list=custom_black_list)

    @LazyInitialized
    def __fake_vars(self):
        return _create_fake_var()

    @LazyInitialized
    def _double_grads(self):
        return self._get_double_grads(self._origin_main_program)

    @LazyInitialized
    def _infer_program(self):
        """
        Lazy initialized property of infer_program.
        """
        return self._clone_for_test(self._origin_main_program)

    @LazyInitialized
    def _train_program(self):
        """
        Lazy initialized property of train_program.
        """
        train_program = self._append_backward_desc(self._origin_main_program)
        # Note: Only set grad type once after initializing train program. So we
        # put it here.
        self._set_grad_type(self._params, train_program)

        return train_program

    @LazyInitialized
    @switch_to_static_graph
    def _infer_amp_program(self):
        """
        Lazy initialized property of infer_amp_program.
        """
        infer_amp_program = self._origin_main_program.clone()
        with program_guard(infer_amp_program):
            rewrite_program(infer_amp_program, self._amp_list)

        return infer_amp_program

    @LazyInitialized
    def _train_amp_program(self):
        """
        Lazy initialized property of train_amp_program.
        """
        train_amp_program = self._append_backward_desc(self._infer_amp_program)
        self._set_grad_type(self._params, train_amp_program)
        return train_amp_program

    @LazyInitialized
    @switch_to_static_graph
    def _infer_pure_fp16_program(self):
        """
        Lazy initialized property of _infer_pure_fp16_program.
        """
        infer_pure_fp16_program = self._origin_main_program.clone()
        with program_guard(infer_pure_fp16_program):
            cast_model_to_fp16(infer_pure_fp16_program,
                               self._amp_list,
                               use_fp16_guard=False)

        return infer_pure_fp16_program

    @LazyInitialized
    def _train_pure_fp16_program(self):
        """
        Lazy initialized property of _train_pure_fp16_program.
        """
        train_pure_fp16_program = self._append_backward_desc(
            self._infer_pure_fp16_program)
        self._set_grad_type(self._params, train_pure_fp16_program)
        return train_pure_fp16_program

    @LazyInitialized
    def _infer_program_id(self):
        return _hash_with_id(self._infer_program, self)

    @LazyInitialized
    def _train_program_id(self):
        program_id = _hash_with_id(self._train_program, self)
        core._set_cached_executor_build_strategy(program_id,
                                                 self._build_strategy)

        return program_id

    @LazyInitialized
    def _train_amp_program_id(self):
        program_id = _hash_with_id(self._train_amp_program, self)
        core._set_cached_executor_build_strategy(program_id,
                                                 self._build_strategy)

        return program_id

    @LazyInitialized
    def _train_pure_fp16_program_id(self):
        program_id = _hash_with_id(self._train_pure_fp16_program, self)
        core._set_cached_executor_build_strategy(program_id,
                                                 self._build_strategy)

        return program_id

    def _verify_program(self, main_program):
        """
        Verify that every parameter of the program is initialized and prune
        the parameters that no op uses.

        Note: op callstacks are NOT removed here; `_remove_op_call_stack`
        exists for that but is not called by this method.
        """
        # 1. Check all params from main program can be found in self._params
        self._check_params_all_inited(main_program)
        # 2. Prune the parameters not used anywhere in the program.
        self._prune_unused_params(main_program)

        return main_program

    @switch_to_static_graph
    def _append_backward_desc(self, main_program):
        """
        Clone `main_program`, force `is_test=False` on its ops and append
        gradient ops for the tensor outputs.
        """
        # make sure all status of is_test are False in train mode.
        program = _change_is_test_status(main_program.clone(), is_test=False)
        targets = []
        for out in self._outputs.tolist():
            if isinstance(out, framework.Variable):
                targets.append(program.global_block().var(out.name))

        if targets and self._params:
            backward.gradients(targets=targets, inputs=[])

        return program

    def _prune_unused_params(self, program):
        """
        Prune the parameters not used anywhere in the program.
        The `@declarative` may only decorated a sub function which
        contains some unused parameters created in `__init__`.
        So prune these parameters to avoid unnecessary operations in
        `run_program_op`.
        """

        def _is_used(name):
            # A parameter is kept if any op in any block reads or writes it.
            return any(name in op.input_arg_names
                       or name in op.output_arg_names
                       for block in program.blocks for op in block.ops)

        self._params = [
            param for param in self._params if _is_used(param.name)
        ]

    def _get_double_grads(self, program):
        """
        Create a tensor for every variable whose name contains "@GRAD" in any
        block of `program`; fall back to the fake var list when there is none.
        """
        double_grads = []
        for block in program.blocks:
            for name in block.vars:
                if "@GRAD" in name:
                    var_desc = block.vars[name].desc
                    if not framework._in_eager_mode_:
                        var_base = core.VarBase(var_desc.dtype(),
                                                var_desc.shape(),
                                                var_desc.name(),
                                                var_desc.type(), False)
                    else:
                        var_base = core.eager.Tensor(var_desc.dtype(),
                                                     var_desc.shape(),
                                                     var_desc.name(),
                                                     var_desc.type(), False)
                    double_grads.append(var_base)
        return self._valid_vars(double_grads)

    def _get_end_op_index(self):
        """
        Return the number of forward ops of the program that `self.program`
        refers to. This is used as `end_op_index` of `run_program`.
        """
        if not self.training:
            # In eval mode `self.program` is always `self._infer_program`,
            # even under AMP / pure-fp16 guards. The AMP-rewritten infer
            # programs contain extra cast ops. Using their op count here
            # would yield an end index past the last op of the executed
            # program.
            infer_program = self._infer_program
        elif _in_amp_guard():
            infer_program = self._infer_amp_program
        elif _in_pure_fp16_guard():
            infer_program = self._infer_pure_fp16_program
        else:
            infer_program = self._infer_program
        return infer_program.desc.block(0).op_size()

    def __call__(self, inputs):
        in_vars, out_vars = self._prepare(inputs)

        attrs = [
            'global_block',
            self.program.desc.block(0), 'start_op_index', 0, 'end_op_index',
            self._get_end_op_index(), 'is_test', not self.training,
            'program_id', self.program_id
        ]
        if self._cuda_graph_capture_mode:
            attrs.extend(
                ('cuda_graph_capture_mode', self._cuda_graph_capture_mode,
                 'cuda_graph_pool_id', self._cuda_graph_pool_id))

        self._cast_fp16_if_pure_fp16(in_vars)

        _C_ops.run_program(self._valid_vars(in_vars),
                           self._valid_vars(self._params),
                           self._valid_vars(out_vars), self._tmp_scope_vec,
                           self._double_grads, self._cuda_graph_vec, *attrs)
        self.drop_scope_if_no_grad()
        restored_nest_out = self._restore_out(out_vars)
        return self._remove_no_value(restored_nest_out)

    def _cast_fp16_if_pure_fp16(self, in_vars):
        """
        Under the pure-fp16 guard, cast (in place in `in_vars`) each input
        whose program variable is float16, keeping the input's name.
        """
        if _in_pure_fp16_guard():
            # Look the block up once instead of once per input.
            global_block = self.program.global_block()
            for i, var in enumerate(in_vars):
                name = var.name
                if (global_block.has_var(name)
                        and global_block.var(name).dtype == paddle.float16):
                    in_vars[i] = var.astype('float16')
                    in_vars[i].name = name

    def drop_scope_if_no_grad(self):
        """
        Drop the child scopes of the temporary scope when training while the
        tracer has gradient recording disabled.
        """
        tracer = framework._dygraph_tracer()
        scope = self._tmp_scope_vec.value().get_scope() if isinstance(
            self._tmp_scope_vec, (core.VarBase)) else self._tmp_scope_vec[0]
        if self.training and not tracer._has_grad:
            scope.drop_kids()

    @property
    def program(self):
        """
        The program executed by `__call__`.

        In training mode this is the train program matching the active
        AMP / pure-fp16 guard. In eval mode it is always `_infer_program`.
        """
        if self.training:
            if _in_amp_guard():
                return self._train_amp_program
            elif _in_pure_fp16_guard():
                return self._train_pure_fp16_program
            else:
                return self._train_program
        else:
            return self._infer_program

    @property
    def program_id(self):
        """
        Id of `self.program`, selected with the same rules as `program`.
        """
        if self.training:
            if _in_amp_guard():
                return self._train_amp_program_id
            elif _in_pure_fp16_guard():
                return self._train_pure_fp16_program_id
            else:
                return self._train_program_id
        else:
            return self._infer_program_id

    def _prepare(self, inputs):
        """
        Prepare inputs, outputs, attrs.
        """
        assert isinstance(inputs, (tuple, list))
        # Flatten inputs with nested structure into single list.
        flatten_inputs = flatten(inputs)
        # Convert variable into VarBase and feed in training data.
        input_vars = []
        expected_place = framework._current_expected_place()
        for i, value in enumerate(flatten_inputs):
            if isinstance(value, np.ndarray):
                if not framework._in_eager_mode_:
                    var = core.VarBase(value=value,
                                       name=self._inputs[i].desc.name(),
                                       persistable=False,
                                       place=expected_place,
                                       zero_copy=True)
                else:
                    var = core.eager.Tensor(value=value,
                                            name=self._inputs[i].desc.name(),
                                            persistable=False,
                                            place=expected_place,
                                            zero_copy=True)
            elif isinstance(value, (core.VarBase, core.eager.Tensor)):
                # NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
                # into CUDAPlace when it's as input of multi Ops. so we move it in advance
                # to avoid this problem.
                if value.stop_gradient and not value.place._equals(
                        expected_place):
                    var = value._copy_to(expected_place, False)
                    var.stop_gradient = True
                else:
                    var = value
                var.name = self._inputs[i].desc.name()
            else:
                continue
            input_vars.append(var)

        def create_out(var_id):
            # Allocate an empty tensor matching the desc of the output var.
            var = self._outputs[var_id]
            assert isinstance(var, framework.Variable)
            var_desc = var.desc
            if not framework._in_eager_mode_:
                var_base = core.VarBase(var_desc.dtype(), var_desc.shape(),
                                        var_desc.name(), var_desc.type(), False)
            else:
                var_base = core.eager.Tensor(var_desc.dtype(), var_desc.shape(),
                                             var_desc.name(), var_desc.type(),
                                             False)
            return var_base

        # Create VarBase to receive output data.
        out_vars = list(map(create_out, self._outputs.var_ids))

        return input_vars, out_vars

    def _create_scope_vec(self):
        # Hold forward variables
        inner_scope = core.Scope()
        if not framework._in_eager_mode_:
            tmp_scope_vec = core.VarBase(core.VarDesc.VarType.FP32, [],
                                         "program_out_scope",
                                         core.VarDesc.VarType.STEP_SCOPES, True)
            tmp_scope_vec.value().set_scope(inner_scope)
        else:
            tmp_scope_vec = [inner_scope]
        return tmp_scope_vec

    def _create_cuda_graph_vec(self):
        var = core.VarBase(core.VarDesc.VarType.FP32, [], "cuda_graph",
                           core.VarDesc.VarType.RAW, True)
        var.stop_gradient = True
        return var

    def _restore_out(self, out_vars):
        """
        Restores same nested outputs by only replacing the Variable with VarBase.
        """

        flatten_outputs = self._outputs.tolist()
        for i, idx in enumerate(self._outputs.var_ids):
            flatten_outputs[idx] = out_vars[i]
        outs = self._outputs.restore(flatten_outputs)
        if outs is not None and len(outs) == 1:
            outs = outs[0]

        return outs

    @switch_to_static_graph
    def _clone_for_test(self, main_program):
        return main_program.clone(for_test=True)

    def _is_no_value(self, var):
        """
        Return True if `var` is a shape-[1] tensor holding
        RETURN_NO_VALUE_MAGIC_NUM.
        """
        if isinstance(var,
                      (core.VarBase, core.eager.Tensor)) and var.shape == [1]:
            # NOTE: .numpy() will insert MemcpySync operation, it hits performance.
            if var.numpy()[0] == RETURN_NO_VALUE_MAGIC_NUM:
                return True
        return False

    def _remove_no_value(self, out_vars):
        """
        Removes invalid value for various-length return statement
        """
        if isinstance(out_vars, (core.VarBase, core.eager.Tensor)):
            if self._is_no_value(out_vars):
                return None
            return out_vars
        elif isinstance(out_vars, (tuple, list)):
            if isinstance(out_vars, tuple):
                res = tuple(var for var in out_vars
                            if not self._is_no_value(var))
            else:
                # isinstance(out_vars, list)
                res = [var for var in out_vars if not self._is_no_value(var)]

            has_removed = (len(out_vars) > len(res))
            # len(out_vars) > len(res) means we have removed var. This is
            # preventing out_vars is empty or just one element at the beginning
            if len(res) == 0 and has_removed:
                return None
            elif len(res) == 1 and has_removed:
                return res[0]
            return res

        return out_vars

    def _set_grad_type(self, params, train_program):
        # NOTE: if user set sparse gradient mode, the param's gradient
        # will be SelectedRows, not LoDTensor. But tracer will just
        # set param grad VarBase by forward VarBase(LoDTensor)
        # If we don't change grad_var type here, RunProgramOp need
        # transform SelectedRows to LoDTensor forcibly, it may not
        # be user wanted result.
        for param in params:
            grad_name = param.name + core.grad_var_suffix()
            grad_var = train_program.desc.block(0).find_var(
                cpt.to_bytes(grad_name))
            # NOTE: cannot find var desc maybe no problem, such as in batch_norm
            if grad_var is None:
                continue
            param._set_grad_type(grad_var.type())

    def _remove_op_call_stack(self, main_program):
        """
        Remove op's python call stack with redundant low-level error messages related to
        transformations to avoid confusing users.
        """
        assert isinstance(main_program, framework.Program)
        for block in main_program.blocks:
            for op in block.ops:
                if op.has_attr("op_callstack"):
                    op._remove_attr("op_callstack")

        return main_program

    def _check_params_all_inited(self, main_program):
        """
        Check all params from main program are already initialized, see details as follows:
            1. all parameters in self._params should be type `framework.ParamBase` which are created in dygraph.
            2. all parameters from transformed program can be found in self._params.
               Because they share same data with ParamBase of original dygraph.
        """
        if not isinstance(self._params, (list, tuple)):
            raise TypeError(
                "Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
                % type(self._params))

        param_and_buffer_names_set = set()
        for i, var in enumerate(self._params):
            # self._params contains parameters and buffers with persistable=True.
            if not isinstance(var, (core.VarBase, core.eager.Tensor)):
                raise TypeError(
                    'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'
                    .format(i, type(var)))
            param_and_buffer_names_set.add(var.name)

        for block in main_program.blocks:
            for name, var in six.iteritems(block.vars):
                if isinstance(var, framework.Parameter):
                    if name not in param_and_buffer_names_set:
                        raise ValueError(
                            "\n\tWe don't support to define layer with parameters in the function decorated by `@to_static`."
                            "\n\tBut we found parameter(%s) was created in the decorated function."
                            "\n"
                            "\n\tRevise suggestion: "
                            "\n\t\t1. Please ensure all your sublayers are inheritted from nn.Layer."
                            "\n\t\t2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
                            % name)

    def _valid_vars(self, vars):
        """
        Note: run_program_op.InferShape requires `X`/'Out' not be null.
        But it's common in dy2static, fake varBase is created to handle the
        problem.
        """
        return vars if vars else self.__fake_vars

def _create_fake_var():
    """
    Create a fake variable (force on CPU) to stand in for empty inputs or
    outputs.

    Returns:
        list: A one-element list holding a RAW-typed tensor named "Fake_var".
        It is a `core.eager.Tensor` in eager mode, otherwise a `core.VarBase`.
    """
    tensor_cls = core.eager.Tensor if framework._in_eager_mode_ else core.VarBase
    fake = tensor_cls(core.VarDesc.VarType.FP32, [], "Fake_var",
                      core.VarDesc.VarType.RAW, False)
    return [fake]


def partial_program_from(concrete_program):
    """
    Build a `PartialProgramLayer` from `concrete_program`.

    If the first input is a `layers.Layer` instance, it is dropped before the
    inputs are wrapped. Presumably this is the bound `self` of a decorated
    method; verify against the caller.
    """
    inputs = concrete_program.inputs
    leading_layer = bool(inputs) and isinstance(inputs[0], layers.Layer)
    if leading_layer:
        inputs = inputs[1:]

    return PartialProgramLayer(concrete_program.main_program,
                               inputs,
                               concrete_program.outputs,
                               concrete_program.parameters,
                               **concrete_program.kwargs)