#   Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import time

from paddle.distributed.passes import new_pass
from paddle.static import append_backward, program_guard
from paddle.utils import unique_name

from ..utils.log_utils import get_logger
from .partitioner import Partitioner
from .process_group import get_world_process_group
from .reshard import Resharder
from .utils import set_grad_var_shape


class Parallelizer:
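    """Convert a completed serial program into per-rank distributed programs.

    For each rank, the workflow is: generate the backward part (train mode
    only), apply pre-optimization passes (amp/bf16/fp16, quantization,
    recompute), partition the serial program logically, generate the
    optimizer ops, insert reshard communication, and apply post-optimization
    passes (data-parallel optimization, sharding, gradient clip, gradient
    merge).
    """
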
    def __init__(self, mode, completer, dist_context):
        self._mode = mode
        self._completer = completer
        self._dist_context = dist_context
        assert self._dist_context._is_initialized
        self._pass_context = self._dist_context.pass_context
        self._strategy = self._dist_context.strategy
        self._logger = get_logger(logging.INFO)

    def parallel_all(self):
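        """Parallelize the serial program for every rank in the world process group."""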
        world_process_group = get_world_process_group()
        all_ranks = world_process_group.ranks
        for rank in all_ranks:
            # self._dist_context._backup(serial=True, dist=True)
            self.parallel(rank)
            # self._dist_context._restore(serial=True, dist=True)

    def parallel(self, rank):
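        """Parallelize the serial program for the given rank and store the resulting distributed programs in the dist_context."""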
        serial_main_program = self._dist_context.serial_main_program
        serial_startup_program = self._dist_context.serial_startup_program
        serial_optimizer = self._dist_context.serial_optimizer
        if self._mode == "train" and serial_optimizer:
            # Generate backward
            serial_loss = self._dist_context.serial_loss
            params_grads = self._generate_backward(
                serial_main_program, serial_startup_program, serial_loss
            )
            # Apply pre optimization passes
            time0 = time.time()
            (
                serial_main_program,
                serial_startup_program,
                params_grads,
            ) = self._apply_pre_optimization(
                serial_main_program,
                serial_startup_program,
                serial_loss,
                serial_optimizer,
                params_grads,
            )
            self._logger.debug(
                "within parallel apply_pre_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do logical partition
            time0 = time.time()
            partitioner = Partitioner(self._dist_context, rank)
            (
                dist_main_prog,
                dist_startup_prog,
                dist_params_grads,
            ) = partitioner.partition(
                serial_main_program, serial_startup_program, params_grads
            )
            self._logger.debug(
                "within parallel partitioner time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Generate optimizer
            time0 = time.time()
            self._generate_optimizer(
                dist_main_prog,
                dist_startup_prog,
                serial_optimizer,
                dist_params_grads,
            )
            self._logger.debug(
                "within parallel optimizer time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do reshard process
            time0 = time.time()
            set_grad_var_shape(dist_main_prog, self._dist_context)
            resharder = Resharder(
                dist_main_prog,
                dist_startup_prog,
                rank,
                self._dist_context,
                dist_params_grads,
            )
            resharder.reshard()
            self._logger.debug(
                "within parallel reshard time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Apply post optimization passes
            time0 = time.time()
            self._apply_post_optimization(
                dist_main_prog, dist_startup_prog, rank, dist_params_grads
            )
            self._logger.debug(
                "within parallel apply_post_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
        else:
            # Apply pre optimization passes
            time0 = time.time()
            (
                serial_main_program,
                serial_startup_program,
                params_grads,
            ) = self._apply_pre_optimization(
                serial_main_program, serial_startup_program, None, None, []
            )
            self._logger.debug(
                "within parallel apply_pre_optimization time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do logical partition
            time0 = time.time()
            partitioner = Partitioner(self._dist_context, rank)
            (
                dist_main_prog,
                dist_startup_prog,
                dist_params_grads,
            ) = partitioner.partition(
                serial_main_program, serial_startup_program, []
            )
            self._logger.debug(
                "within parallel partitioner time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
            # Do reshard process
            time0 = time.time()
            resharder = Resharder(
                dist_main_prog,
                dist_startup_prog,
                rank,
                self._dist_context,
                [],
                1,
            )
            resharder.reshard()
            self._logger.debug(
                "within parallel reshard time: {}, mode {}".format(
                    time.time() - time0, self._mode
                )
            )
        # Clone program for test
        if self._mode != 'train':
            dist_main_prog = dist_main_prog.clone(for_test=True)
            dist_startup_prog = dist_startup_prog.clone(for_test=True)

        # Store the distributed programs for later use
        self._dist_context.dist_main_programs[rank] = dist_main_prog
        self._dist_context.dist_startup_programs[rank] = dist_startup_prog

    def _generate_backward(self, main_program, startup_program, loss):
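        """Append backward ops for `loss` within the serial program and complete their distributed attributes."""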
        with program_guard(main_program, startup_program):
            params_grads = append_backward(
                loss, distop_context=self._dist_context.dist_op_context
            )
        self._completer.complete_backward_annotation(main_program)
        self._dist_context.block_state.parse_backward_blocks(main_program)
        return params_grads

    def _generate_optimizer(
        self, main_program, startup_program, optimizer, params_grads
    ):
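        """Append optimizer ops by applying a copy of `optimizer` to `params_grads` and complete their distributed attributes."""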
        # NOTE: `apply_gradients` will add an Accumulator for a parameter only once,
        # but the optimizer will be called repeatedly on re-launch, so it needs to be copied.
        optimizer = copy.deepcopy(optimizer)
        self._dist_context._serial_optimizer = optimizer
        with program_guard(main_program, startup_program):
            with unique_name.guard("opt_"):
                optimizer_ops = optimizer.apply_gradients(params_grads)
        self._completer.complete_update_annotation(main_program)
        return optimizer_ops

    def _apply_pre_optimization(
        self, main_program, startup_program, loss, optimizer, params_grads
    ):
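        """Apply passes on the serial program before partition: amp/bf16/fp16, quantization and recompute."""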
        if self._strategy is None:
            return

        # apply amp pass on train/eval/predict
        if self._strategy.amp.enable:
            config = copy.deepcopy(self._strategy.amp.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["loss"] = loss
            config["input_data"] = (
                self._dist_context.serial_feed_vars["inputs"]
                + self._dist_context.serial_feed_vars["labels"]
            )
            if config["enable_bf16"]:
                auto_parallel_bf16_pass = new_pass("auto_parallel_bf16", config)
                auto_parallel_bf16_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_bf16_pass.get_loss()

            elif config["use_pure_fp16"]:
                config["base_opt"] = optimizer
                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                auto_parallel_fp16_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_fp16_pass.get_loss()

            else:
                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                auto_parallel_amp_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_amp_pass.get_loss()

        # apply quantization pass
        # This pass can only be applied when the mode is 'train'
        if self._strategy.qat.enable:
            config = copy.deepcopy(self._strategy.qat.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["mode"] = self._mode
            config["loss"] = loss
            auto_parallel_quantization_pass = new_pass(
                "auto_parallel_quantization", config
            )
            auto_parallel_quantization_pass.apply(
                [main_program], [startup_program], self._pass_context
            )
            main_program = self._pass_context.get_attr("main_program")
            startup_program = self._pass_context.get_attr("startup_program")
            params_grads = self._pass_context.get_attr("params_grads")
            loss = self._pass_context.get_attr("loss")

        # apply recompute pass
        # recompute is a train-only optimization
        if self._mode == "train" and self._strategy.recompute.enable:
            config = copy.deepcopy(self._strategy.recompute.to_dict())
            config["dist_context"] = self._dist_context
            config["no_grad_set"] = None
            config["loss"] = loss
            auto_parallel_recompute_pass = new_pass(
                "auto_parallel_recompute", config
            )
            auto_parallel_recompute_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

        return main_program, startup_program, params_grads

    def _apply_post_optimization(
        self, main_program, startup_program, rank, params_grads
    ):
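        """Apply passes on the distributed program after partition and reshard: data-parallel optimization, sharding, gradient clip, explicit dependencies for the new executor, and gradient merge."""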
        if self._strategy is None:
            return

        # data parallel optimization
        config = {}
        config["dist_context"] = self._dist_context
        config["global_rank"] = rank
        config["use_sharding"] = self._strategy.sharding.enable
        dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
        dp_pass.apply([main_program], [startup_program], self._pass_context)

        if self._strategy.sharding.enable:
            config = copy.deepcopy(self._strategy.sharding.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["global_rank"] = rank
            auto_parallel_sharding_pass = new_pass(
                "auto_parallel_sharding", config
            )
            auto_parallel_sharding_pass.apply(
                [main_program], [startup_program], self._pass_context
            )
            params_grads = self._pass_context.get_attr("params_grads")

        # GradClip is a train-only optimization
        if self._mode == "train":
            config = copy.deepcopy(self._strategy.sharding.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["rank_id"] = rank
            auto_parallel_clip_pass = new_pass(
                "auto_parallel_grad_clip", config
            )
            auto_parallel_clip_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

            # supplement explicit dependencies required by the new executor
            config = {}
            config["dist_context"] = self._dist_context
            APSED_pass = new_pass(
                "auto_parallel_supplement_explicit_dependencies", config
            )
            APSED_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

        # gradient_merge is a train-only optimization
        if self._mode == "train" and self._strategy.gradient_merge.enable:
            config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            auto_parallel_gradient_merge_pass = new_pass(
                "auto_parallel_gradient_merge_pass", config
            )
            auto_parallel_gradient_merge_pass.apply(
                [main_program], [startup_program], self._pass_context
            )