#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import time

from paddle.distributed.passes import PassManager, new_pass
from paddle.static import append_backward, program_guard
from paddle.utils import unique_name

from ..utils.log_utils import get_logger
from .partitioner import Partitioner
from .process_group import get_world_process_group
from .reshard import Resharder
from .utils import set_grad_var_shape


class Parallelizer:
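    """Convert the serial programs recorded in a DistributedContext into
    per-rank distributed main/startup programs.

    A minimal usage sketch (assuming ``completer`` and ``dist_context`` are
    already built by the caller, e.g. the auto-parallel engine)::

        parallelizer = Parallelizer("train", completer, dist_context)
        parallelizer.parallel_all()
    """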
    def __init__(self, mode, completer, dist_context):
        self._mode = mode
        self._completer = completer
        self._dist_context = dist_context
        assert self._dist_context._is_initialized
        self._pass_context = self._dist_context.pass_context
        self._strategy = self._dist_context.strategy
        self._logger = get_logger(logging.INFO)

    def parallel_all(self):
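        """Build distributed programs for every rank in the world process group."""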
        world_process_group = get_world_process_group()
        all_ranks = world_process_group.ranks
        for rank in all_ranks:
            # self._dist_context._backup(serial=True, dist=True)
            self.parallel(rank)
            # self._dist_context._restore(serial=True, dist=True)

    def parallel(self, rank):
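        """Build the distributed main/startup programs for the given rank.

        In "train" mode with an optimizer this runs, in order: backward
        generation, pre-optimization passes, logical partition, optimizer
        generation, reshard, and post-optimization passes; otherwise only
        the pre-optimization passes, partition, and reshard are applied,
        and non-train programs are cloned with ``for_test=True``.
        """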
        serial_main_program = self._dist_context.serial_main_program
        serial_startup_program = self._dist_context.serial_startup_program
        serial_optimizer = self._dist_context.serial_optimizer
        if self._mode == "train" and serial_optimizer:
            # Generate backward
            serial_loss = self._dist_context.serial_loss
            params_grads = self._generate_backward(
                serial_main_program, serial_startup_program, serial_loss
            )
            # Apply pre optimization passes
            time0 = time.time()
            (
                serial_main_program,
                serial_startup_program,
                params_grads,
            ) = self._apply_pre_optimization(
                serial_main_program,
                serial_startup_program,
                serial_loss,
                serial_optimizer,
                params_grads,
            )
            self._logger.debug(
                "within parallel apply_pre_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do logical partition
            time0 = time.time()
            partitioner = Partitioner(self._dist_context, rank)
            (
                dist_main_prog,
                dist_startup_prog,
                dist_params_grads,
            ) = partitioner.partition(
                serial_main_program, serial_startup_program, params_grads
            )
            self._logger.debug(
                "within parallel partitioner time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Generate optimizer
            time0 = time.time()
            self._generate_optimizer(
                dist_main_prog,
                dist_startup_prog,
                serial_optimizer,
                dist_params_grads,
            )
            self._logger.debug(
                "within parallel optimizer time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do reshard process
            time0 = time.time()
            set_grad_var_shape(dist_main_prog, self._dist_context)
            resharder = Resharder(
                dist_main_prog,
                dist_startup_prog,
                rank,
                self._dist_context,
                dist_params_grads,
            )
            resharder.reshard()
            self._logger.debug(
                "within parallel reshard time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Apply post optimization passes
            time0 = time.time()
            self._apply_post_optimization(
                dist_main_prog, dist_startup_prog, rank, dist_params_grads
            )
            self._logger.debug(
                "within parallel apply_post_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
        else:
            # Apply pre optimization passes
            time0 = time.time()
            (
                serial_main_program,
                serial_startup_program,
                params_grads,
            ) = self._apply_pre_optimization(
                serial_main_program, serial_startup_program, None, None, []
            )
            self._logger.debug(
                "within parallel apply_pre_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do logical partition
            time0 = time.time()
            partitioner = Partitioner(self._dist_context, rank)
            (
                dist_main_prog,
                dist_startup_prog,
                dist_params_grads,
            ) = partitioner.partition(
                serial_main_program, serial_startup_program, []
            )
            self._logger.debug(
                "within parallel partitioner time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do reshard process
            time0 = time.time()
            resharder = Resharder(
                dist_main_prog,
                dist_startup_prog,
                rank,
                self._dist_context,
                [],
                1,
            )
            resharder.reshard()
            self._logger.debug(
                "within parallel reshard time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
        # Clone programs for test
        if self._mode != 'train':
            dist_main_prog = dist_main_prog.clone(for_test=True)
            dist_startup_prog = dist_startup_prog.clone(for_test=True)

        # Store the distributed programs for further usage
        self._dist_context.dist_main_programs[rank] = dist_main_prog
        self._dist_context.dist_startup_programs[rank] = dist_startup_prog

    def _generate_backward(self, main_program, startup_program, loss):
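        """Append backward ops for ``loss`` and complete their distributed attributes."""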
        with program_guard(main_program, startup_program):
            params_grads = append_backward(
                loss, distop_context=self._dist_context.dist_op_context
            )
        self._completer.complete_backward_annotation(main_program)
        self._dist_context.block_state.parse_backward_blocks(main_program)
        return params_grads

    def _generate_optimizer(
        self, main_program, startup_program, optimizer, params_grads
    ):
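        """Apply the (copied) optimizer to ``params_grads`` and complete the
        distributed attributes of the resulting update ops."""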
        # NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
        # but the optimizer will be called repeatedly on re-launch, so it needs to be copied.
        optimizer = copy.deepcopy(optimizer)
        self._dist_context._serial_optimizer = optimizer
        with program_guard(main_program, startup_program):
            with unique_name.guard("opt_"):
                optimizer_ops = optimizer.apply_gradients(params_grads)
        self._completer.complete_update_annotation(main_program)
        return optimizer_ops

    def _apply_pre_optimization(
        self, main_program, startup_program, loss, optimizer, params_grads
    ):
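        """Apply passes that run before partition: AMP/FP16, quantization (QAT),
        and recompute."""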
        if self._strategy is None:
            return main_program, startup_program, params_grads

        # apply amp pass on train/eval/predict
        if self._strategy.amp.enable:
            config = copy.deepcopy(self._strategy.amp.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["loss"] = loss
            config["input_data"] = (
                self._dist_context.serial_feed_vars["inputs"]
                + self._dist_context.serial_feed_vars["labels"]
            )
            self._logger.info(
                "Applying AMP-{}-{} ...".format(
                    config["dtype"], config['level']
                ),
            )
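            # level o1 uses the auto_parallel_amp pass; o2/o3 use the
            # auto_parallel_fp16 pass, which additionally needs the base optimizer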
            if config['level'] == "o1":
                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                auto_parallel_amp_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_amp_pass.get_loss()
            elif config['level'] in ['o2', 'o3']:
                config["base_opt"] = optimizer
                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                auto_parallel_fp16_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_fp16_pass.get_loss()
            else:
                raise ValueError("AMP level should be one of o1, o2, o3")

        # apply quantization pass
        # This pass should only be applied when the mode is 'train'
        if self._strategy.qat.enable:
            config = copy.deepcopy(self._strategy.qat.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["mode"] = self._mode
            config["loss"] = loss
            auto_parallel_quantization_pass = new_pass(
                "auto_parallel_quantization", config
            )
            auto_parallel_quantization_pass.apply(
                [main_program], [startup_program], self._pass_context
            )
            main_program = self._pass_context.get_attr("main_program")
            startup_program = self._pass_context.get_attr("startup_program")
            params_grads = self._pass_context.get_attr("params_grads")
            loss = self._pass_context.get_attr("loss")

        # apply recompute pass
        # recompute is a train-only optimization
        if self._mode == "train" and self._strategy.recompute.enable:
            config = copy.deepcopy(self._strategy.recompute.to_dict())
            config["dist_context"] = self._dist_context
            config["no_grad_set"] = None
            config["loss"] = loss
            auto_parallel_recompute_pass = new_pass(
                "auto_parallel_recompute", config
            )
            auto_parallel_recompute_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

        return main_program, startup_program, params_grads

    def _apply_post_optimization(
        self, main_program, startup_program, rank, params_grads
    ):
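        """Apply passes that run after partition and reshard: data-parallel
        optimization, sharding, gradient clipping, explicit dependencies for
        the new executor, gradient merge, and user-specified fused passes."""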
        if self._strategy is None:
            return

        # data parallel optimization
        config = {}
        config["dist_context"] = self._dist_context
        config["global_rank"] = rank
        config["use_sharding"] = self._strategy.sharding.enable
        dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
        dp_pass.apply([main_program], [startup_program], self._pass_context)

        if self._strategy.sharding.enable:
            config = copy.deepcopy(self._strategy.sharding.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["global_rank"] = rank
            auto_parallel_sharding_pass = new_pass(
                "auto_parallel_sharding", config
            )
            auto_parallel_sharding_pass.apply(
                [main_program], [startup_program], self._pass_context
            )
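            # the sharding pass may rewrite the (param, grad) pairs, so pick up
            # the updated list from the pass context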
            params_grads = self._pass_context.get_attr("params_grads")

        # GradClip is a train-only optimization
        if self._mode == "train":
            config = copy.deepcopy(self._strategy.sharding.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["rank_id"] = rank
            auto_parallel_clip_pass = new_pass(
                "auto_parallel_grad_clip", config
            )
            auto_parallel_clip_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

            # add explicit dependencies required by the new executor (newexe)
            config = {}
            config["dist_context"] = self._dist_context
            APSED_pass = new_pass(
                "auto_parallel_supplement_explicit_dependencies", config
            )
            APSED_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

        # gradient_merge is a train-only optimization
        if self._mode == "train" and self._strategy.gradient_merge.enable:
            config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            auto_parallel_gradient_merge_pass = new_pass(
                "auto_parallel_gradient_merge_pass", config
            )
            auto_parallel_gradient_merge_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

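        # finally, apply any user-specified fused passes listed in the strategy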
        if self._mode == "train" and self._strategy.fused_passes.enable:
            if len(self._strategy.fused_passes.fused_passes_list) > 0:
                new_pass_list = []
                for op in self._strategy.fused_passes.fused_passes_list:
                    new_pass_list.append(new_pass(op))
                pass_manager = PassManager(new_pass_list)
                pass_manager.apply([main_program], [startup_program])