parallelizer_v2.py 8.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33
#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections import defaultdict

from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.distributed.passes import new_pass

from .reshard import Resharder
from .partitioner import Partitioner
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
from .process_group import get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context


class Parallelizer:
    """Turn the serial programs held by a distributed context into per-rank
    distributed programs.

    For a given rank the parallelizer (optionally) generates the backward
    part, applies the strategy-driven optimization passes, partitions the
    serial programs, reshards them and stores the resulting main/startup
    programs back into the distributed context.
    """

    def __init__(self, mode, completer, dist_context):
        """Remember the collaborators and cache the pass context and strategy.

        Args:
            mode: Execution mode string. ``"train"`` (together with a serial
                optimizer) selects the training path in :meth:`parallel`;
                any other value makes the generated programs be cloned
                with ``for_test=True``.
            completer: Object providing ``complete_backward_annotation`` and
                ``complete_update_annotation``.
            dist_context: Distributed context; it must already be
                initialized.

        Raises:
            AssertionError: If ``dist_context`` is not initialized.
        """
        self._mode = mode
        self._completer = completer
        self._dist_context = dist_context
        assert self._dist_context._is_initialized, \
            "The distributed context must be initialized before parallelizing."
        self._pass_context = self._dist_context.pass_context
        self._strategy = self._dist_context.strategy

    def parallel_all(self):
        """Run :meth:`parallel` for every rank of the world process group."""
        world_process_group = get_world_process_group()
        for rank in world_process_group.ranks:
            # NOTE: backing up / restoring the distributed context around
            # each rank is currently disabled.
            # self._dist_context._backup(serial=True, dist=True)
            self.parallel(rank)
            # self._dist_context._restore(serial=True, dist=True)

    def parallel(self, rank):
        """Generate and store the distributed programs for ``rank``.

        In ``"train"`` mode with a serial optimizer available, the backward
        part and the optimizer ops are generated and both pre and post
        optimization passes are applied. Otherwise only the pre optimization
        passes, the partition and the reshard are performed.

        Args:
            rank: The rank the distributed programs are generated for; it is
                also the key under which they are stored in the distributed
                context.
        """
        serial_main_program = self._dist_context.serial_main_program
        serial_startup_program = self._dist_context.serial_startup_program
        serial_optimizer = self._dist_context.serial_optimizer
        if self._mode == "train" and serial_optimizer:
            # Generate backward
            serial_loss = self._dist_context.serial_loss
            params_grads = self._generate_backward(serial_main_program,
                                                   serial_startup_program,
                                                   serial_loss)
            # Apply pre optimization passes
            self._apply_pre_optimization(serial_main_program,
                                         serial_startup_program, serial_loss,
                                         serial_optimizer, params_grads)
            # Do logical partition
            partitioner = Partitioner(self._dist_context, rank)
            dist_main_prog, dist_startup_prog, dist_params_grads = \
                partitioner.partition(serial_main_program,
                                      serial_startup_program, params_grads)
            # Generate optimizer
            self._generate_optimizer(dist_main_prog, dist_startup_prog,
                                     serial_optimizer, dist_params_grads)
            # Do reshard process
            set_grad_var_shape(dist_main_prog, self._dist_context)
            resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                                  self._dist_context, dist_params_grads)
            resharder.reshard()
            # Apply post optimization passes
            self._apply_post_optimization(dist_main_prog, dist_startup_prog,
                                          rank, dist_params_grads)
        else:
            # Apply pre optimization passes (no loss, optimizer or gradients
            # are available on this path)
            self._apply_pre_optimization(serial_main_program,
                                         serial_startup_program, None, None,
                                         None)
            # Do logical partition
            partitioner = Partitioner(self._dist_context, rank)
            dist_main_prog, dist_startup_prog, dist_params_grads = \
                partitioner.partition(serial_main_program,
                                      serial_startup_program, [])
            # Do reshard process
            resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                                  self._dist_context, [], 1)
            resharder.reshard()

        # Clone program for test
        if self._mode != 'train':
            dist_main_prog = dist_main_prog.clone(for_test=True)
            dist_startup_prog = dist_startup_prog.clone(for_test=True)

        # Store the distributed programs for further usages
        self._dist_context.dist_main_programs[rank] = dist_main_prog
        self._dist_context.dist_startup_programs[rank] = dist_startup_prog

    def _generate_backward(self, main_program, startup_program, loss):
        """Append the backward part for ``loss`` and complete its annotations.

        Args:
            main_program: Program the backward ops are appended to.
            startup_program: Startup program guarded together with
                ``main_program``.
            loss: Loss variable handed to ``append_backward``.

        Returns:
            The value returned by ``append_backward`` (the parameter/gradient
            pairs).
        """
        with program_guard(main_program, startup_program):
            params_grads = append_backward(
                loss, distop_context=self._dist_context.dist_op_context)
        self._completer.complete_backward_annotation(main_program)
        self._dist_context.block_state.parse_backward_blocks(main_program)
        return params_grads

    def _generate_optimizer(self, main_program, startup_program, optimizer,
                            params_grads):
        """Apply the gradients with a copy of ``optimizer`` and complete the
        update annotations.

        Args:
            main_program: Program the optimizer ops are added to.
            startup_program: Startup program guarded together with
                ``main_program``.
            optimizer: Serial optimizer; it is deep-copied so the caller's
                instance is not the one applying the gradients.
            params_grads: Parameter/gradient pairs passed to
                ``apply_gradients``.

        Returns:
            The value returned by ``apply_gradients``.
        """
        with program_guard(main_program, startup_program):
            optimizer_ops = copy.deepcopy(optimizer).apply_gradients(
                params_grads)
        self._completer.complete_update_annotation(main_program)
        return optimizer_ops

    def _apply_pre_optimization(self, main_program, startup_program, loss,
                                optimizer, params_grads):
        """Apply the AMP/FP16 and recompute passes enabled by the strategy.

        Does nothing when no strategy is set. Every pass is applied with the
        shared pass context cached in ``__init__``.

        Args:
            main_program: Serial main program the passes are applied to.
            startup_program: Serial startup program the passes are applied to.
            loss: Loss variable put into the pass configs (may be ``None``).
            optimizer: Optimizer used as ``base_opt`` of the pure FP16 pass
                (may be ``None``).
            params_grads: Parameter/gradient pairs put into the AMP config
                (may be ``None``).
        """
        if self._strategy is None:
            return

        # apply amp pass
        if self._strategy.amp:
            config = copy.deepcopy(self._strategy.amp_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["loss"] = loss
            config["input_data"] = self._dist_context.serial_feed_vars["inputs"] \
                + self._dist_context.serial_feed_vars["labels"]
            if config["use_pure_fp16"]:
                config["base_opt"] = optimizer
                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                              self._pass_context)
            else:
                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                auto_parallel_amp_pass.apply([main_program], [startup_program],
                                             self._pass_context)

        # apply recompute pass
        if self._strategy.recompute:
            config = copy.deepcopy(self._strategy.recompute_configs)
            config["dist_context"] = self._dist_context
            config["no_grad_set"] = None
            config["loss"] = loss
            auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
                                                    config)
            # The third argument of ``apply`` is the pass context, not the
            # distributed context (which already travels inside ``config``).
            auto_parallel_recompute_pass.apply([main_program],
                                               [startup_program],
                                               self._pass_context)

    def _apply_post_optimization(self, main_program, startup_program, rank,
                                 params_grads):
        """Apply the sharding and gradient merge passes enabled by the
        strategy.

        Does nothing when no strategy is set. Every pass is applied with the
        shared pass context cached in ``__init__``.

        Args:
            main_program: Distributed main program the passes are applied to.
            startup_program: Distributed startup program the passes are
                applied to.
            rank: Rank put into the sharding config as ``global_rank``.
            params_grads: Parameter/gradient pairs put into the pass configs.
        """
        if self._strategy is None:
            return

        if self._strategy.sharding:
            config = copy.deepcopy(self._strategy.sharding_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["global_rank"] = rank
            auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
                                                   config)
            # Pass the pass context; the distributed context is in ``config``.
            auto_parallel_sharding_pass.apply([main_program], [startup_program],
                                              self._pass_context)

        if self._strategy.gradient_merge:
            config = copy.deepcopy(self._strategy.gradient_merge_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            auto_parallel_gradient_merge_pass = new_pass(
                "auto_parallel_gradient_merge_pass", config)
            # Pass the pass context; the distributed context is in ``config``.
            auto_parallel_gradient_merge_pass.apply([main_program],
                                                    [startup_program],
                                                    self._pass_context)