#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import time
import logging

from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import unique_name
from paddle.distributed.passes import new_pass

from .reshard import Resharder
from .partitioner import Partitioner
from .utils import set_grad_var_shape
from .process_group import get_world_process_group
from .random import init_auto_parallel_rng
from ..utils.log_utils import get_logger


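# Parallelizer turns a completed serial program into per-rank distributed
# programs: in "train" mode it generates the backward pass and the optimizer,
# applies pre-optimization passes, partitions the program logically, reshards,
# and applies post-optimization passes; other modes skip backward/optimizer.
#
# A minimal usage sketch (for illustration only; in the codebase this class is
# normally driven by the auto-parallel engine, and `completer` / `dist_context`
# are assumed to be already initialized objects):
#
#     parallelizer = Parallelizer("train", completer, dist_context)
#     parallelizer.parallel(rank)       # build programs for a single rank
#     # or parallelizer.parallel_all()  # build programs for every world rank
#     main_prog = dist_context.dist_main_programs[rank]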
class Parallelizer:
    def __init__(self, mode, completer, dist_context):
        self._mode = mode
        self._completer = completer
        self._dist_context = dist_context
        assert self._dist_context._is_initialized
        self._pass_context = self._dist_context.pass_context
        self._strategy = self._dist_context.strategy
        self._logger = get_logger(logging.INFO)

    def parallel_all(self):
        world_process_group = get_world_process_group()
        all_ranks = world_process_group.ranks
        for rank in all_ranks:
            # self._dist_context._backup(serial=True, dist=True)
            self.parallel(rank)
            # self._dist_context._restore(serial=True, dist=True)

    def parallel(self, rank):
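        # In "train" mode with an optimizer, the full pipeline runs for this
        # rank: backward generation, pre-optimization passes, logical
        # partition, optimizer generation, reshard, then post-optimization
        # passes. The other branch skips backward and optimizer generation,
        # and non-train programs are finally cloned for test.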
        serial_main_program = self._dist_context.serial_main_program
        serial_startup_program = self._dist_context.serial_startup_program
        serial_optimizer = self._dist_context.serial_optimizer
        if self._mode == "train" and serial_optimizer:
            # Generate backward
            serial_loss = self._dist_context.serial_loss
            params_grads = self._generate_backward(
                serial_main_program, serial_startup_program, serial_loss
            )
            # Apply pre optimization passes
            time0 = time.time()
            (
                serial_main_program,
                serial_startup_program,
                params_grads,
            ) = self._apply_pre_optimization(
                serial_main_program,
                serial_startup_program,
                serial_loss,
                serial_optimizer,
                params_grads,
            )
            self._logger.debug(
                "within parallel apply_pre_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do logical partition
            time0 = time.time()
            partitioner = Partitioner(self._dist_context, rank)
            (
                dist_main_prog,
                dist_startup_prog,
                dist_params_grads,
            ) = partitioner.partition(
                serial_main_program, serial_startup_program, params_grads
            )

            init_auto_parallel_rng()

            self._logger.debug(
                "within parallel partitioner time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Generate optimizer
            time0 = time.time()
            self._generate_optimizer(
                dist_main_prog,
                dist_startup_prog,
                serial_optimizer,
                dist_params_grads,
            )
            self._logger.debug(
                "within parallel optimizer time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do reshard process
            time0 = time.time()
            set_grad_var_shape(dist_main_prog, self._dist_context)
            resharder = Resharder(
                dist_main_prog,
                dist_startup_prog,
                rank,
                self._dist_context,
                dist_params_grads,
            )
            resharder.reshard()
            self._logger.debug(
                "within parallel reshard time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Apply post optimization passes
            time0 = time.time()
            self._apply_post_optimization(
                dist_main_prog, dist_startup_prog, rank, dist_params_grads
            )
            self._logger.debug(
                "within parallel apply_post_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
        else:
            # Apply pre optimization passes
            time0 = time.time()
            self._apply_pre_optimization(
                serial_main_program, serial_startup_program, None, None, None
            )
            self._logger.debug(
                "within parallel apply_pre_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do logical partition
            time0 = time.time()
            partitioner = Partitioner(self._dist_context, rank)
            (
                dist_main_prog,
                dist_startup_prog,
                dist_params_grads,
            ) = partitioner.partition(
                serial_main_program, serial_startup_program, []
            )
            self._logger.debug(
                "within parallel partitioner time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            micro_bsz = (
                1
                if not self._strategy.pipeline.enable
                else self._strategy.pipeline.micro_batch_size
            )
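            # The micro batch size is handed to Resharder below as an extra
            # argument; presumably it lets resharding account for the
            # micro-batch split when pipeline parallelism is enabled (an
            # assumption based on this call site, not on Resharder itself).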
            # Do reshard process
            time0 = time.time()
            resharder = Resharder(
                dist_main_prog,
                dist_startup_prog,
                rank,
                self._dist_context,
                [],
                micro_bsz,
            )
            resharder.reshard()
            self._logger.debug(
                "within parallel reshard time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Apply post optimization passes
            time0 = time.time()
            self._apply_post_optimization(
                dist_main_prog, dist_startup_prog, rank, dist_params_grads
            )
            self._logger.debug(
                "within parallel apply_post_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
        # Clone program for test
        if self._mode != 'train':
            pipeline_opt = dist_main_prog._pipeline_opt
            dist_main_prog = dist_main_prog.clone(for_test=True)
            dist_startup_prog = dist_startup_prog.clone(for_test=True)
            dist_main_prog._pipeline_opt = pipeline_opt

        # Store the distributed programs for further use
        self._dist_context.dist_main_programs[rank] = dist_main_prog
        self._dist_context.dist_startup_programs[rank] = dist_startup_prog

    def _generate_backward(self, main_program, startup_program, loss):
        with program_guard(main_program, startup_program):
            params_grads = append_backward(
                loss, distop_context=self._dist_context.dist_op_context
            )
        self._completer.complete_backward_annotation(main_program)
        self._dist_context.block_state.parse_backward_blocks(main_program)
        return params_grads

    def _generate_optimizer(
        self, main_program, startup_program, optimizer, params_grads
    ):
        # NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
        # but the optimizer will be called repeatedly on re-launch, so it needs to be copied.
        optimizer = copy.deepcopy(optimizer)
        self._dist_context._serial_optimizer = optimizer
        with program_guard(main_program, startup_program):
            with unique_name.guard("opt_"):
                optimizer_ops = optimizer.apply_gradients(params_grads)
        self._completer.complete_update_annotation(main_program)
        return optimizer_ops

    def _apply_pre_optimization(
        self, main_program, startup_program, loss, optimizer, params_grads
    ):
        if self._strategy is None:
            return

        # apply quantization pass
        # This pass can only be applied when the mode is 'train'
        if self._mode == 'train' and self._strategy.qat.enable:
            config = copy.deepcopy(self._strategy.qat.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            auto_parallel_quantization_pass = new_pass(
                "auto_parallel_quantization", config
            )
            auto_parallel_quantization_pass.apply(
                [main_program], [startup_program], self._pass_context
            )
            main_program = self._pass_context.get_attr("main_program")
            startup_program = self._pass_context.get_attr("startup_program")
            params_grads = self._pass_context.get_attr("params_grads")

        # apply amp pass on train/eval/predict
        if self._strategy.amp.enable:
            config = copy.deepcopy(self._strategy.amp.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["loss"] = loss
            config["input_data"] = (
                self._dist_context.serial_feed_vars["inputs"]
                + self._dist_context.serial_feed_vars["labels"]
            )
            self._logger.info(
                "Applying AMP-{}-{} ...".format(
                    config["dtype"], config['level']
                ),
            )
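            # Level "o1" is handled by the auto_parallel_amp pass; "o2"/"o3" go
            # through the auto_parallel_fp16 pass, which additionally needs the
            # base optimizer. Both passes hand back the (possibly rewritten) loss.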
            if config['level'] == "o1":
                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                auto_parallel_amp_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_amp_pass.get_loss()
            elif config['level'] in ['o2', 'o3']:
                config["base_opt"] = optimizer
                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                auto_parallel_fp16_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_fp16_pass.get_loss()
            else:
                raise ValueError("AMP level should be one of o1, o2, o3")

        # apply recompute pass
        # recompute is a train-only optimization
        if self._mode == "train" and self._strategy.recompute.enable:
            config = copy.deepcopy(self._strategy.recompute.to_dict())
            config["dist_context"] = self._dist_context
            config["no_grad_set"] = None
            config["loss"] = loss
            auto_parallel_recompute_pass = new_pass(
                "auto_parallel_recompute", config
            )
            auto_parallel_recompute_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

        return main_program, startup_program, params_grads

    def _apply_post_optimization(
        self, main_program, startup_program, rank, params_grads
    ):
        if self._strategy is None:
            return

        if self._strategy.dp_optimization.enable:
            config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
            config["dist_context"] = self._dist_context
            config["global_rank"] = rank
            config["use_sharding"] = self._strategy.sharding.enable
            dp_pass = new_pass(
                "auto_parallel_data_parallel_optimization", config
            )
            dp_pass.apply([main_program], [startup_program], self._pass_context)

        if self._strategy.sharding.enable:
            config = copy.deepcopy(self._strategy.sharding.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["global_rank"] = rank
            auto_parallel_sharding_pass = new_pass(
                "auto_parallel_sharding", config
            )
            auto_parallel_sharding_pass.apply(
                [main_program], [startup_program], self._pass_context
            )
            params_grads = self._pass_context.get_attr("params_grads")

        # GradClip is a train-only optimization
        if self._mode == "train":
            config = copy.deepcopy(self._strategy.sharding.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["rank_id"] = rank
            auto_parallel_clip_pass = new_pass(
                "auto_parallel_grad_clip", config
            )
            auto_parallel_clip_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

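        # Enabling pipeline parallelism forces gradient merge on, with k_steps
        # tied to the pipeline's accumulate_steps (presumably so gradients are
        # accumulated across micro-batches before each optimizer update).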
        if self._strategy.pipeline.enable:
            self._strategy.gradient_merge.enable = True
            self._strategy.gradient_merge.k_steps = (
                self._strategy.pipeline.accumulate_steps
            )
            self._strategy.gradient_merge.avg = True

        # gradient_merge is a train-only optimization
        if self._mode == "train" and self._strategy.gradient_merge.enable:
            config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            auto_parallel_gradient_merge_pass = new_pass(
                "auto_parallel_gradient_merge_pass", config
            )
            auto_parallel_gradient_merge_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

        if self._strategy.pipeline.enable:
            config = copy.deepcopy(self._strategy.pipeline.to_dict())
            config["dist_context"] = self._dist_context
            auto_parallel_pipeline_pass = new_pass(
                "auto_parallel_pipeline", config
            )
            auto_parallel_pipeline_pass.apply(
                [main_program], [startup_program], self._pass_context
            )