#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import time
import logging
from collections import defaultdict

import paddle
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode, unique_name
from paddle.distributed.passes import new_pass

from .reshard import Resharder
from .partitioner import Partitioner
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .dist_loader import NonIterableGeneratorLoader
from .utils import make_data_unshard, set_grad_var_shape
from .utils import print_program_with_dist_attr, to_list
from .utils import get_logger
from .process_group import get_all_process_groups, get_world_process_group
from .dist_context import DistributedContext, get_default_distributed_context


class Parallelizer:
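    """
    Convert the completed serial program into distributed programs, one per rank.

    Given an initialized DistributedContext and a completer, the parallelizer
    generates backward and optimizer ops (in "train" mode), applies the enabled
    optimization passes, partitions the serial program for each rank, and
    reshards tensors by inserting the necessary communication ops.

    A minimal usage sketch (assuming `completer` and `dist_context` have been
    built and completed elsewhere):

        parallelizer = Parallelizer("train", completer, dist_context)
        parallelizer.parallel_all()
        dist_main_prog = dist_context.dist_main_programs[rank]
    """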

    def __init__(self, mode, completer, dist_context):
        self._mode = mode
        self._completer = completer
        self._dist_context = dist_context
        assert self._dist_context._is_initialized
        self._pass_context = self._dist_context.pass_context
        self._strategy = self._dist_context.strategy
        self._logger = get_logger(logging.INFO)

    def parallel_all(self):
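        """Parallelize the serial program for every rank in the world process
        group."""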
        world_process_group = get_world_process_group()
        all_ranks = world_process_group.ranks
        for rank in all_ranks:
            # self._dist_context._backup(serial=True, dist=True)
            self.parallel(rank)
            # self._dist_context._restore(serial=True, dist=True)

    def parallel(self, rank):
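        """Parallelize the serial program for the given rank.

        In "train" mode, backward and optimizer ops are generated and the
        pre/post optimization passes are applied; otherwise only the forward
        program is partitioned and resharded, then cloned for test.
        """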
        serial_main_program = self._dist_context.serial_main_program
        serial_startup_program = self._dist_context.serial_startup_program
        serial_optimizer = self._dist_context.serial_optimizer
        if self._mode == "train" and serial_optimizer:
            # Generate backward
            serial_loss = self._dist_context.serial_loss
            params_grads = self._generate_backward(serial_main_program,
                                                   serial_startup_program,
                                                   serial_loss)
            # Apply pre optimization passes
            time0 = time.time()
            serial_main_program, serial_startup_program, params_grads = self._apply_pre_optimization(
                serial_main_program, serial_startup_program, serial_loss,
                serial_optimizer, params_grads)
            self._logger.info(
                "within parallel apply_pre_optimization time: {}, mode {}".
                format(time.time() - time0, self._mode))
            # Do logical partition
            time0 = time.time()
            partitioner = Partitioner(self._dist_context, rank)
            dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
                serial_main_program, serial_startup_program, params_grads)
            self._logger.info(
                "within parallel partitioner time: {}, mode {}".format(
                    time.time() - time0, self._mode))
            # Generate optimizer
            time0 = time.time()
            self._generate_optimizer(dist_main_prog, dist_startup_prog,
                                     serial_optimizer, dist_params_grads)
            self._logger.info(
                "within parallel optimizer time: {}, mode {}".format(
                    time.time() - time0, self._mode))
            # Do reshard process
            time0 = time.time()
            set_grad_var_shape(dist_main_prog, self._dist_context)
            resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                                  self._dist_context, dist_params_grads)
            resharder.reshard()
            self._logger.info(
                "within parallel reshard time: {}, mode {}".format(
                    time.time() - time0, self._mode))
            # Apply post optimization passes
            time0 = time.time()
            self._apply_post_optimization(dist_main_prog, dist_startup_prog,
                                          rank, dist_params_grads)
            self._logger.info(
                "within parallel apply_post_optimization time: {}, mode {}".
                format(time.time() - time0, self._mode))
        else:
            # Apply pre optimization passes
            time0 = time.time()
            self._apply_pre_optimization(serial_main_program,
                                         serial_startup_program, None, None,
                                         None)
            self._logger.info(
                "within parallel apply_pre_optimization time: {}, mode {}".
                format(time.time() - time0, self._mode))
            # Do logical partition
            time0 = time.time()
            partitioner = Partitioner(self._dist_context, rank)
            dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
                serial_main_program, serial_startup_program, [])
            self._logger.info(
                "within parallel partitioner time: {}, mode {}".format(
                    time.time() - time0, self._mode))
            # Do reshard process
            time0 = time.time()
            resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
                                  self._dist_context, [], 1)
            resharder.reshard()
            self._logger.info(
                "within parallel reshard time: {}, mode {}".format(
                    time.time() - time0, self._mode))
        # Clone program for test
        if self._mode != 'train':
            dist_main_prog = dist_main_prog.clone(for_test=True)
            dist_startup_prog = dist_startup_prog.clone(for_test=True)

        # Store the distributed programs for further usage
        self._dist_context.dist_main_programs[rank] = dist_main_prog
        self._dist_context.dist_startup_programs[rank] = dist_startup_prog

    def _generate_backward(self, main_program, startup_program, loss):
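        """Append backward ops for `loss` and complete their distributed
        attributes."""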
        with program_guard(main_program, startup_program):
            params_grads = append_backward(
                loss, distop_context=self._dist_context.dist_op_context)
        self._completer.complete_backward_annotation(main_program)
        self._dist_context.block_state.parse_backward_blocks(main_program)
        return params_grads

    def _generate_optimizer(self, main_program, startup_program, optimizer,
                            params_grads):
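        """Append optimizer ops for `params_grads` and complete their
        distributed attributes."""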
        # NOTE: `apply_gradients` will only add an accumulator for a parameter once,
        # but the optimizer will be called repeatedly on re-launch, so it needs to be
        # deep-copied before use.
        optimizer = copy.deepcopy(optimizer)
        self._dist_context._lr_optimizer = optimizer
        with program_guard(main_program, startup_program):
            with unique_name.guard("opt_"):
                optimizer_ops = optimizer.apply_gradients(params_grads)
        self._completer.complete_update_annotation(main_program)
        return optimizer_ops

    def _apply_pre_optimization(self, main_program, startup_program, loss,
                                optimizer, params_grads):
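        """Apply pre-partition passes (quantization, AMP/FP16, recompute)
        enabled by the strategy."""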
        if self._strategy is None:
            return

        # apply quantization pass
        # The pass can only be applied in 'train' mode
        if self._mode == 'train' and self._strategy.qat.enable:
            config = copy.deepcopy(self._strategy.qat.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            auto_parallel_quantization_pass = new_pass(
                "auto_parallel_quantization", config)
            auto_parallel_quantization_pass.apply([main_program],
                                                  [startup_program],
                                                  self._pass_context)
            main_program = self._pass_context.get_attr("main_program")
            startup_program = self._pass_context.get_attr("startup_program")
            params_grads = self._pass_context.get_attr("params_grads")

        # apply amp pass
        # FIXME: we disable AMP for eval for now because it has a minor bug with
        # the eval program, which will be fixed in the future
        if self._strategy.amp.enable:
            config = copy.deepcopy(self._strategy.amp.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["loss"] = loss
            config["input_data"] = self._dist_context.serial_feed_vars["inputs"] \
                + self._dist_context.serial_feed_vars["labels"]
            if config["use_pure_fp16"]:
                config["base_opt"] = optimizer
                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                auto_parallel_fp16_pass.apply([main_program], [startup_program],
                                              self._pass_context)
            else:
                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                auto_parallel_amp_pass.apply([main_program], [startup_program],
                                             self._pass_context)

        # apply recompute pass
        # recompute is a train-only optimization
        if self._mode == "train" and self._strategy.recompute.enable:
            config = copy.deepcopy(self._strategy.recompute.to_dict())
            config["dist_context"] = self._dist_context
            config["no_grad_set"] = None
            config["loss"] = loss
            auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
                                                    config)
            auto_parallel_recompute_pass.apply([main_program],
                                               [startup_program],
                                               self._pass_context)

        return main_program, startup_program, params_grads

    def _apply_post_optimization(self, main_program, startup_program, rank,
                                 params_grads):
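        """Apply post-partition passes (data-parallel optimization, sharding,
        grad clip, gradient merge) enabled by the strategy."""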
        if self._strategy is None:
            return

        # data parallel optimization
        config = {}
        config["dist_context"] = self._dist_context
        config["global_rank"] = rank
        config["use_sharding"] = self._strategy.sharding.enable
        dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
        dp_pass.apply([main_program], [startup_program], self._pass_context)

        if self._strategy.sharding.enable:
            config = copy.deepcopy(self._strategy.sharding.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["global_rank"] = rank
            auto_parallel_sharding_pass = new_pass("auto_parallel_sharding",
                                                   config)
            auto_parallel_sharding_pass.apply([main_program], [startup_program],
                                              self._pass_context)
            params_grads = self._pass_context.get_attr("params_grads")

        # GradClip is a train-only optimization
        if self._mode == "train":
            config = copy.deepcopy(self._strategy.sharding.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["rank_id"] = rank
            auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip",
                                               config)
            auto_parallel_clip_pass.apply([main_program], [startup_program],
                                          self._pass_context)

        # gradient_merge is a train-only optimization
        if self._mode == "train" and self._strategy.gradient_merge.enable:
            config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            auto_parallel_gradient_merge_pass = new_pass(
                "auto_parallel_gradient_merge_pass", config)
            auto_parallel_gradient_merge_pass.apply([main_program],
                                                    [startup_program],
                                                    self._pass_context)