auto_parallel_recompute.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging

from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from paddle.fluid import core, framework, unique_name
from paddle.fluid.backward import (
    ProgramStats,
    _append_grad_suffix_,
    _find_op_path_,
    _get_no_grad_set_name,
    _rename_arg_,
)

from ..auto_parallel.dist_attribute import OperatorDistAttr
from ..auto_parallel.utils import (
    get_loss_op,
    insert_dependencies_for_two_ops,
    is_backward_op,
    is_recompute_op,
    naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
    set_dist_op_desc_original_id,
    set_var_dist_attr,
)
from .pass_base import PassBase, register_pass


class RecomputeState(ProgramStats):
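    """Track the forward ops of one block for the recompute pass.

    ``seg_op_deps`` maps each recompute namescope to the contiguous indices of
    its ops, ``_checkpoints`` records output vars that are kept in memory, and
    ``_reserved_vars`` records seed vars that must not be recomputed.
    """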
    def __init__(self, block, ops):
        super().__init__(block=block, ops=ops)
        self.seg_op_deps = {}
        self._checkpoints = []
        self._reserved_vars = []

    @property
    def checkpoints(self):
        return self._checkpoints

    @property
    def reserved_vars(self):
        return self._reserved_vars

    def is_recompute(self):
        return any([is_recompute_op(op) for op in self.ops])

    def build_states(self):
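        """Build the var -> op dependency maps for the forward ops and group
        the recompute ops into segments keyed by their 'op_namescope'."""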
        for i, op in enumerate(self.ops):
            if is_backward_op(op):
                break

            for name in op.input_arg_names:
                if name in self.var_op_deps:
                    self.var_op_deps[name]["var_as_input_ops"].extend([i])
                else:
                    self.var_op_deps[name] = {}
                    self.var_op_deps[name]["var_as_input_ops"] = [i]
                    self.var_op_deps[name]["var_as_output_ops"] = []

            for name in op.output_arg_names:
                if name in self.var_op_deps:
                    self.var_op_deps[name]["var_as_output_ops"].extend([i])
                else:
                    self.var_op_deps[name] = {}
                    self.var_op_deps[name]["var_as_input_ops"] = []
                    self.var_op_deps[name]["var_as_output_ops"] = [i]

            if not is_recompute_op(op):
                self._checkpoints.extend(op.output_arg_names)
                continue

            seg_name = op.attr('op_namescope')
            if seg_name not in self.seg_op_deps:
                self.seg_op_deps[seg_name] = [i]
            else:
                assert (
                    self.seg_op_deps[seg_name][-1] + 1 == i
                ), "The recompute segment's ops should be continuous"
                self.seg_op_deps[seg_name].extend([i])

    def get_recompute_segments(self, no_recompute_segments=[]):
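        """Return the [start, end) op index range of every recompute segment,
        dropping the segments listed in `no_recompute_segments`."""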
        segments = []
        for segment_idx in self.seg_op_deps.values():
            if len(segment_idx) == 1:
                continue
            segments.append([segment_idx[0], segment_idx[-1] + 1])
            self._checkpoints.extend(self.ops[segment_idx[-1]].output_arg_names)

        for i in reversed(sorted(no_recompute_segments)):
            assert i < len(
                segments
            ), "the no_recompute_segments idx [{}] should be lower than the number of segments [{}]".format(
                i, len(segments)
            )
            segments.pop(i)

        return segments

    def modify_forward_desc_for_recompute(self, dist_context):
        """
112
        If program's foward part has 'dropout' op, this function will insert
113 114
        a seed op before it to guarantee that two dropout op have the same outputs.
        """
        op_types = [op.type for op in self.ops]
        if "dropout" not in op_types:
            return

        op_idx = 0
        while op_idx < len(self.ops):
            cur_op = self.ops[op_idx]
            if "grad" in cur_op.type:
                break
            if cur_op.type == "seed":
                self._reserved_vars.extend(cur_op.output_arg_names)
                op_idx += 1
                continue
            if cur_op.type != "dropout":
                op_idx += 1
                continue
            if cur_op.input("Seed") is not None and len(cur_op.input("Seed")):
                op_idx += 1
                continue

            cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op)
            # insert a seed op to guarantee that the two dropout ops have the same outputs
            op_unique_name = unique_name.generate("seed")
            var_unique_name = unique_name.generate_with_ignorable_key(
                ".".join([op_unique_name, 'tmp'])
            )
            self._reserved_vars.append(var_unique_name)
            seed_var = self.block.create_var(
                name=var_unique_name,
                dtype='int32',
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=False,
            )

            # set new seed_var's dist_attr
            ref_dims_mapping = [-1]
            ref_process_mesh = cur_op_dist_attr.process_mesh
            seed_var_dist_attr = set_var_dist_attr(
                dist_context, seed_var, ref_dims_mapping, ref_process_mesh
            )

            seed = (
                0
                if cur_op.attr("fix_seed") is False
                else int(cur_op.attr("seed"))
            )
            # TODO: add a dependency for the seed op to ensure it is issued just before recomputation.
            seed_op = self.block._insert_op_without_sync(
                index=cur_op.idx,
                type="seed",
                inputs={},
                outputs={"Out": seed_var},
                attrs={"seed": seed, "force_cpu": True},
            )
            seed_op._set_attr('op_namescope', cur_op.attr('op_namescope'))
            # set new seed op's dist_attr
            naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
                seed_op, ref_process_mesh, ref_dims_mapping, dist_context
            )

            # modify dropout op's desc
            self.ops.insert(op_idx, seed_op)
            cur_op.desc.set_input("Seed", [var_unique_name])
            cur_op._remove_attr("fix_seed")
            cur_op._remove_attr("seed")
            cur_op_dist_attr.set_input_dist_attr(
                seed_var.name, seed_var_dist_attr
            )
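            # skip over the newly inserted seed op and the current dropout op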
            op_idx += 2

        self.block._sync_with_cpp()


def _find_op_index(block, cur_op):
    for idx in range(block.desc.op_size()):
        if cur_op.desc == block.desc.op(idx):
            return idx
    return -1


def _get_stop_gradients(program, no_grad_set=None):
    """get no grad var"""
    if no_grad_set is None:
        no_grad_set = set()
    else:
        no_grad_set = _get_no_grad_set_name(no_grad_set)

    no_grad_set_name = set()
    for var in program.list_vars():
        if "@GRAD" in var.name:
            break
        if var.stop_gradient:
            no_grad_set_name.add(_append_grad_suffix_(var.name))
    no_grad_set_name.update(list(map(_append_grad_suffix_, no_grad_set)))
    return no_grad_set_name


def _add_needed_descs_to_block(
    descs, block, main_block, vars_should_be_hold, dist_context
):
    """
    Get the recomputed ops which will insert the backward part
    """
    if len(descs) == 0:
        return []

    result_descs = []
    for desc in descs:
        if isinstance(desc, framework.Operator):
            desc = desc.desc
        if isinstance(desc, tuple):
            desc = desc[0]
        is_needed = False
        for name in desc.output_arg_names():
            if main_block.has_var(name) and main_block.var(name).persistable:
                continue
            if name not in vars_should_be_hold:
                is_needed = True
        if is_needed:
            new_op_desc = block.desc.append_op()
            new_op_desc.copy_from(desc)
            set_dist_op_desc_original_id(new_op_desc, desc, dist_context)
            new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
            result_descs.append(new_op_desc)
    return result_descs


def _find_op_path(main_program, loss, no_grad_set=None):
    no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
    op_path = _find_op_path_(
        main_program.global_block(), [loss], [], no_grad_set_name
    )
    return op_path


@register_pass("auto_parallel_recompute")
class RecomputePass(PassBase):
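    """Recompute (activation checkpointing) pass for auto parallel training.

    Forward ops inside each recompute segment are copied into the backward
    region of the program so that their intermediate activations can be
    recomputed during the backward pass instead of being kept in memory.
    """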
    def __init__(self):
        super().__init__()
        self.set_attr("loss", None)
        self.set_attr("dist_context", None)
        self.set_attr("no_grad_set", None)
        self.set_attr("no_recompute_segments", [])

    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False
        if self.get_attr("loss") is None:
            return False
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, context):
        loss = self.get_attr("loss")
        no_grad_set = self.get_attr("no_grad_set")
        no_recompute_segments = self.get_attr("no_recompute_segments")
        self._dist_context = self.get_attr("dist_context")

        # 0. get op_path which is related to loss
        main_block = main_program.global_block()
        op_path = _find_op_path(main_program, loss, no_grad_set)

        # 1. build recompute state
        rc_state = RecomputeState(main_block, op_path)
        if not rc_state.is_recompute():
            return

        # 2. get the segments to be recomputed
        rc_state.modify_forward_desc_for_recompute(self._dist_context)
        rc_state.build_states()
        segments = rc_state.get_recompute_segments(no_recompute_segments)
        if segments == []:
            return

        for i, (idx1, idx2) in enumerate(segments):
            logging.info(
                "recompute segment[{}/{}]".format(i + 1, len(segments))
            )
            logging.info(
                "segment start op: [{}]: [{}] [{}]".format(
                    rc_state.ops[idx1].type,
                    rc_state.ops[idx1].input_arg_names,
                    rc_state.ops[idx1].output_arg_names,
                )
            )
            logging.info(
                "segment end op: [{}]: [{}] [{}]".format(
                    rc_state.ops[idx2 - 1].type,
                    rc_state.ops[idx2 - 1].input_arg_names,
                    rc_state.ops[idx2 - 1].output_arg_names,
                )
            )

        # 3. get vars that should be held in memory
        vars_should_be_hold = []
        for segment in segments:
            vars_should_be_hold.extend(
                rc_state.get_out_of_subgraph_vars(segment[0], segment[1])
            )
        cross_vars = set(vars_should_be_hold) - set(rc_state.checkpoints)
        logging.info(
            "found [{}] vars which cross recompute segment: [{}],"
            "better checkpoints might be set to reduce those vars".format(
                len(cross_vars), cross_vars
            )
        )
        vars_should_be_hold.extend(rc_state.reserved_vars)
        vars_should_be_hold.extend(rc_state.get_input_nodes())
        vars_should_be_hold = list(
            set(vars_should_be_hold) | set(rc_state.checkpoints)
        )

        # 4. get the fwd ops desc to be recomputed.
        var_name_dict = {}  # varname --> varname.subprog_XXX
        ckpt_ops_dict = {}  # ckpt_op_id --> segment_descs
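        # a temporary block that only holds the copied fwd op descs until they
        # are inserted into the backward region of main_block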
        buffer_block = main_block.program._create_block()
        for i, segment in enumerate(segments[::-1]):
            fwd_ops = op_path[segment[0] : segment[1]]
            var_suffix = ".subprog_%d" % i
            for op in fwd_ops:
                input_and_output_names = []
                input_and_output_names.extend(op.input_arg_names)
                input_and_output_names.extend(op.output_arg_names)

                cur_op_dist_attr = (
                    self._dist_context.get_op_dist_attr_for_program(op)
                )
                assert cur_op_dist_attr is not None

                for name in input_and_output_names:
                    if (
                        main_block.var(name).persistable
                        or name in vars_should_be_hold
                    ):
                        continue
                    if name not in var_name_dict:
                        ref_process_mesh = cur_op_dist_attr.process_mesh
                        if name in op.input_arg_names:
                            ref_dims_mapping = (
                                cur_op_dist_attr.get_input_dims_mapping(name)
                            )
                        else:
                            ref_dims_mapping = (
                                cur_op_dist_attr.get_output_dims_mapping(name)
                            )

                        # record recomputed var's old_name and new_name (old_name.subprog_XXX)
                        # create new var with new name
                        var_name_dict[name] = name + var_suffix
                        ref_var = main_block.var(name)
                        rc_var = main_block.create_var(
                            name=var_name_dict[name],
                            shape=ref_var.shape,
                            dtype=ref_var.dtype,
                            type=ref_var.type,
                            persistable=ref_var.persistable,
                            stop_gradient=ref_var.stop_gradient,
                        )
                        # set new recomputed var's dist attr
                        set_var_dist_attr(
                            self._dist_context,
                            rc_var,
                            ref_dims_mapping,
                            ref_process_mesh,
                        )
            # get recomputed segment's descs
            segment_descs = _add_needed_descs_to_block(
                fwd_ops,
                buffer_block,
                main_block,
                vars_should_be_hold,
                self._dist_context,
            )
            # rename recomputed ops' input and output var name
            for key in var_name_dict:
                _rename_arg_(segment_descs, key, var_name_dict[key])

            # NOTE: one forward op could correspond to multiple xxx_grad ops.
            # When traversing all grad_ops in reverse, we need a flag to indicate
            # whether the ckpt and its segment_descs can be used.
            ckpt_op = op_path[segment[1] - 1]
            ckpt_ops_dict[ckpt_op.desc.original_id()] = [True, segment_descs]

        # 5. insert the recomputed fwd ops into the backward pass
        ops = main_block.ops
        loss_op = get_loss_op(main_block)
        loss_op_idx = _find_op_index(main_block, loss_op)
        dist_op_context = self._dist_context.dist_op_context
        assert loss_op_idx != -1
        # Traverse all grad_ops in reverse; if the fwd op corresponding to a grad op
        # is a checkpoint, the segment's ops should be inserted.
        for i in range(len(ops) - 1, loss_op_idx, -1):
            grad_op = ops[i]
            # remove some attrs of dropout_grad op's desc
            if grad_op.type == "dropout_grad":
                grad_op._remove_attr("fix_seed")
                grad_op._remove_attr("seed")

            input_and_output_names = []
            input_and_output_names.extend(grad_op.input_arg_names)
            input_and_output_names.extend(grad_op.output_arg_names)

            for varname in var_name_dict:
                if varname not in input_and_output_names:
                    continue
                self.reset_op_dist_attr(grad_op, var_name_dict)
                _rename_arg_([grad_op.desc], varname, var_name_dict[varname])

            # insert recomputed ops
            original_id = grad_op.desc.original_id()
            if original_id in dist_op_context.grad_op_id_to_op_id:
                fwd_op_id = dist_op_context.grad_op_id_to_op_id[original_id]
                if fwd_op_id in ckpt_ops_dict and ckpt_ops_dict[fwd_op_id][0]:
                    idx = grad_op.idx
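                    # move the insertion point before any 'sum' ops (gradient
                    # accumulation) that directly precede this grad op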
                    while idx - 1 >= 0 and ops[idx - 1].type == "sum":
                        idx -= 1
                    segment_descs = ckpt_ops_dict[fwd_op_id][1]
                    rc_op = None
                    for _, op_desc in reversed(list(enumerate(segment_descs))):
                        rc_op = main_block._insert_op_without_sync(
                            idx, type='nop'
                        )
                        rc_desc = rc_op.desc
                        rc_desc.copy_from(op_desc)
                        rc_desc.set_original_id(rc_desc.id())
                        # set recomputed ops' dist attr
                        fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
                            op_desc.original_id()
                        )
                        assert fwd_op_dist_attr is not None
                        self.set_op_dist_attr(
                            rc_op, fwd_op_dist_attr, var_name_dict
                        )

                    ckpt_ops_dict[fwd_op_id][0] = False
                    if rc_op:
                        prior_op = main_block.ops[rc_op.idx - 1]
                        posterior_op = rc_op
                        prior_mesh = (
                            self._dist_context.get_op_dist_attr_for_program(
                                prior_op
                            ).process_mesh
                        )
                        posterior_mesh = (
                            self._dist_context.get_op_dist_attr_for_program(
                                posterior_op
                            ).process_mesh
                        )
                        # NOTE: if two recompute segments cross two pipeline stages,
                        # there is no need to add dependencies for them
                        if prior_mesh == posterior_mesh:
                            insert_dependencies_for_two_ops(
                                main_block,
                                idx,
                                prior_op,
                                posterior_op,
                                self._dist_context,
                                is_recompute=True,
                                sync=False,
                                op_namescope="recompute_segment_dep",
                            )
        main_program._sync_with_cpp()

    def reset_op_dist_attr(self, op, var_name_dict):
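        """Register the dist attrs of a grad op's renamed inputs/outputs under
        their new recomputed var names as well."""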
        op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
        assert op_dist_attr is not None
        for input in op.input_arg_names:
            if input in var_name_dict.keys():
                in_dist_attr = op_dist_attr.get_input_dist_attr(input)
                op_dist_attr.set_input_dist_attr(
                    var_name_dict[input], in_dist_attr
                )
        for output in op.output_arg_names:
            if output in var_name_dict.keys():
                out_dist_attr = op_dist_attr.get_output_dist_attr(output)
                op_dist_attr.set_output_dist_attr(
                    var_name_dict[output], out_dist_attr
                )

    def set_op_dist_attr(self, op, old_dist_attr, var_name_dict):
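        """Build a new OperatorDistAttr for a recomputed op from the forward
        op's dist attr, renaming the vars that appear in var_name_dict."""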
        new_dist_attr = OperatorDistAttr()
        new_dist_attr.is_recompute = True
        new_dist_attr.impl_idx = old_dist_attr.impl_idx
        new_dist_attr.impl_type = old_dist_attr.impl_type
        new_dist_attr.process_mesh = old_dist_attr.process_mesh
        for input in old_dist_attr.inputs_dist_attrs.keys():
            if input in var_name_dict.keys():
                in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
                new_dist_attr.set_input_dist_attr(
                    var_name_dict[input], in_dist_attr
                )
            else:
                in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
                new_dist_attr.set_input_dist_attr(input, in_dist_attr)
        for output in old_dist_attr.outputs_dist_attrs.keys():
            if output in var_name_dict.keys():
                out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
                new_dist_attr.set_output_dist_attr(
                    var_name_dict[output], out_dist_attr
                )
            else:
                out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
                new_dist_attr.set_output_dist_attr(output, out_dist_attr)
        self._dist_context.set_op_dist_attr_for_program(op, new_dist_attr)
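

# A minimal usage sketch (hedged): this pass is normally driven by the auto
# parallel engine, and `loss`, `dist_context`, `main_program` and
# `startup_program` below are assumed to be provided by the caller.
# Applying the pass manually could look roughly like this:
#
#     from paddle.distributed.passes import new_pass, PassContext
#
#     recompute_pass = new_pass(
#         "auto_parallel_recompute",
#         {
#             "loss": loss,
#             "dist_context": dist_context,
#             "no_grad_set": None,
#             "no_recompute_segments": [],
#         },
#     )
#     recompute_pass.apply([main_program], [startup_program], PassContext())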