#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import os
import pathlib
import pickle
import shlex
import subprocess
import sys
import time

import paddle
from paddle.distributed.passes import PassContext, new_pass
from paddle.distributed.utils.log_utils import get_logger
from paddle.framework import core
from paddle.static import append_backward, program_guard

from .cluster import Cluster
from .completion import Completer
from .dist_context import DistributedContext, set_default_distributed_context
from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor
from .mapper import mapping
from .partitioner import Partitioner
from .planner import Planner
from .process_group import (
    ProcessGroup,
    _g_process_group_map,
    get_all_process_groups,
    get_process_group,
    get_world_process_group,
)
from .reshard import Resharder
from .utils import SerialProgramInfo, make_data_unshard, set_grad_var_shape

_logger = get_logger(logging.INFO)


class AutoParallelizer:
    """
    AutoParallelizer is the main controller class to do the auto parallel process.
    And the auto parallel process will be triggered in the wrapped parallelize function.
    To facilitate the auto parallelization, it will contain information about program, cluster and the
C
chenxujun 已提交
58 59 60
    related context. In this basic version, the program information will be retrieved from
    Fleet object, and the cluster information can be retrieved in the new created Cluster object,
    and the context information can be retrieved in the new created DistributedContext.
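
    A minimal usage sketch (illustrative only; it assumes a Fleet instance
    that already carries a user-defined optimizer and distributed strategy):

        parallelizer = AutoParallelizer(fleet)
        (
            dist_optimize_ops,
            dist_params_grads,
            dist_startup_prog,
            dist_main_prog,
        ) = parallelizer.parallelize(loss, startup_program)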
    """

    def __init__(self, fleet):
        self._fleet = fleet
        self._optimizer = self._fleet.user_defined_optimizer
        self._dist_strategy = self._fleet._user_defined_strategy
        self._dist_context = DistributedContext()
        self._cluster = None
        self._cluster_topo_path = os.getenv("PADDLE_CLUSTER_TOPO_PATH", None)
        if self._cluster_topo_path is not None:
            self._cluster = Cluster()
            self._cluster.build_from_file(self._cluster_topo_path)
        # Prepare information for auto mapping
        self._rank_mapping_path = os.getenv("PADDLE_RANK_MAPPING_PATH", None)
        enable_auto_mapping_env = os.getenv("PADDLE_ENABLE_AUTO_MAPPING", None)
        self._enable_auto_mapping = enable_auto_mapping_env is not None
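        # NOTE: PADDLE_ENABLE_AUTO_MAPPING switches parallelize() into the
        # two-phase auto-mapping flow: the first launch builds per-rank
        # programs, maps them onto the cluster, and relaunches the job with
        # the rank mapping file written to PADDLE_RANK_MAPPING_PATH, while
        # PADDLE_NEED_RANK_MAPPING (read below) tells this process which
        # phase it is running.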
        self._pass_context = PassContext()

        self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
        self._need_rank_mapping = (
            True
            if self._need_rank_mapping
            and self._need_rank_mapping.lower() == 'true'
            else False
        )
        # self._pass_context = None

    def _remove_distributed_attrs(self, main_program):
        suffix = core.kAutoParallelSuffix()
        # Distributed attributes of variables have already been removed
        # in a previous step, so only op attributes are handled here.
        for block in main_program.blocks:
            for op in block.ops:
                for attr_name in op.attr_names:
                    if suffix in attr_name:
                        op._remove_attr(attr_name)

    def _apply_pre_optimization_passes(
        self, main_program, startup_program, loss, params_grads, no_grad_set
    ):
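        # These passes rewrite the serial program before partitioning: AMP or
        # pure-FP16 rewriting (which may replace the loss variable) and
        # recompute, both configured from the fleet distributed strategy.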
        # apply amp pass
        if self._dist_strategy.amp:
            config = copy.deepcopy(self._dist_strategy.amp_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["loss"] = loss
            if config["use_pure_fp16"]:
                config["base_opt"] = self._optimizer
                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                auto_parallel_fp16_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_fp16_pass.get_loss()
            else:
                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                auto_parallel_amp_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_amp_pass.get_loss()

        # apply recompute pass
        if self._dist_strategy.recompute:
            config = copy.deepcopy(self._dist_strategy.recompute_configs)
            config["dist_context"] = self._dist_context
            config["no_grad_set"] = copy.deepcopy(no_grad_set)
            config["loss"] = loss
            auto_parallel_recompute_pass = new_pass(
                "auto_parallel_recompute", config
            )
            auto_parallel_recompute_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

    def _generate_backward(
        self,
        main_program,
        startup_program,
        loss,
        parameter_list,
        no_grad_set,
        callbacks,
    ):
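        # Append the backward ops to the serial program; dist_op_context is
        # passed so that the Completer can annotate the generated backward
        # ops and gradients afterwards.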

        with program_guard(main_program, startup_program):
            params_grads = append_backward(
                loss,
                parameter_list,
                no_grad_set,
                callbacks,
                distop_context=self._dist_context.dist_op_context,
            )
        self._completer = Completer(self._dist_context)
        self._completer.complete_backward_annotation(main_program)
        self._dist_context.block_state.parse_backward_blocks(main_program)
        return params_grads

    def _apply_optimize(self, main_program, startup_program, params_grads):
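        # Apply a deep copy of the user-defined optimizer to the given params
        # and grads, then complete the dist attrs of the generated update ops.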

        optimizer = copy.deepcopy(self._optimizer)
        with program_guard(main_program, startup_program):
            optimize_ops = optimizer.apply_gradients(params_grads)

        self._dist_context._serial_optimizer = optimizer
        # update completion
        self._completer = Completer(self._dist_context)
        self._completer.complete_update_annotation(main_program)

        return optimize_ops

    def _apply_post_optimization_passes(
        self, main_program, startup_program, rank, params_grads
    ):
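        # Passes applied to the partitioned (per-rank) programs: sharding
        # (optional), gradient clipping, and gradient merge (optional).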

        if self._dist_strategy.sharding:
            config = copy.deepcopy(self._dist_strategy.sharding_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["global_rank"] = rank
            auto_parallel_sharding_pass = new_pass(
                "auto_parallel_sharding", config
            )
            auto_parallel_sharding_pass.apply(
                [main_program], [startup_program], self._pass_context
            )
            params_grads = self._pass_context.get_attr("params_grads")

        config = copy.deepcopy(self._dist_strategy.sharding_configs)
        config["dist_context"] = self._dist_context
        config["params_grads"] = params_grads
        config["rank_id"] = rank
        auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
        auto_parallel_clip_pass.apply(
            [main_program], [startup_program], self._pass_context
        )

        if self._dist_strategy.gradient_merge:
            config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            auto_parallel_gradient_merge_pass = new_pass(
                "auto_parallel_gradient_merge_pass", config
            )
            auto_parallel_gradient_merge_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

    def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
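        # Build the distributed program for one rank:
        #   1. complete the forward annotations (or reuse a searched context);
        #   2. generate the serial backward pass;
        #   3. apply pre-optimization passes (amp/recompute);
        #   4. partition the program logically for this rank;
        #   5. generate optimizer ops, reshard, and run post-optimization
        #      passes.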
        completed_main_program = None
        serial_main_program = self._main_program.clone()
        serial_startup_program = self._startup_program.clone()
        serial_loss = serial_main_program.global_block().var(self._loss.name)

        # Annotate the serial program: complete the dist attrs, or reuse a
        # searched dist context if one is given.
        if dist_context is None:
            # Annotation completion
            self._dist_context = DistributedContext()
            _logger.info("Start annotation dist attr.")
            self._completer = Completer(self._dist_context)
            completed_main_program = (
                self._completer.complete_forward_annotation(serial_main_program)
            )
        else:
            completed_main_program = serial_main_program
            self._dist_context = copy.deepcopy(dist_context)

        # parse forward sub block
        self._dist_context.block_state.parse_forward_blocks(serial_main_program)

        # serial backward pass
        params_grads = self._generate_backward(
            completed_main_program,
            serial_startup_program,
            serial_loss,
            self._parameter_list,
            self._no_grad_set,
            self._callbacks,
        )

        # apply pre-optimization passes on the serial program
        self._apply_pre_optimization_passes(
            completed_main_program,
            serial_startup_program,
            serial_loss,
            params_grads,
            self._no_grad_set,
        )
        # Logical partition
        partitioner = Partitioner(self._dist_context, rank)
        (
            dist_main_prog,
            dist_startup_prog,
            dist_params_grads,
        ) = partitioner.partition(
            completed_main_program, serial_startup_program, params_grads
        )

        # TODO refactor the placement of optimizer
        # generate optimize program
        dist_optimize_ops = self._apply_optimize(
            dist_main_prog, dist_startup_prog, dist_params_grads
        )

        set_grad_var_shape(dist_main_prog, self._dist_context)

        make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)

        resharder = Resharder(
            dist_main_prog,
            dist_startup_prog,
            rank,
            self._dist_context,
            dist_params_grads,
        )
        resharder.reshard()

        self._apply_post_optimization_passes(
            dist_main_prog, dist_startup_prog, rank, dist_params_grads
        )
        g_process_group_map = None
        if not relaunch_phase:
            g_process_group_map = copy.deepcopy(_g_process_group_map)
            _g_process_group_map.clear()
            _g_process_group_map[0] = ProcessGroup(0, [])
            for process_mesh in self._dist_context._process_meshes:
                _g_process_group_map[0].add_ranks(process_mesh.process_ids)
        return (
            dist_optimize_ops,
            dist_params_grads,
            dist_startup_prog,
            dist_main_prog,
            g_process_group_map,
        )

    def parallelize(
        self,
        loss,
        startup_program,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    ):
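        # Two-phase entry point. In the auto-mapping phase, distributed
        # programs are built for every rank, mapped onto the cluster, and the
        # job is relaunched with the resulting rank mapping file. Otherwise
        # (or after the relaunch), the distributed program for the current
        # rank is built and returned.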
        assert (
            startup_program is not None
        ), "startup_program should not be None."
        self._loss = loss
        self._startup_program = startup_program
        self._main_program = loss.block.program
        self._parameter_list = parameter_list
        self._no_grad_set = no_grad_set
        self._callbacks = callbacks

        if self._enable_auto_mapping and self._need_rank_mapping:
            # Do the mapping pass before parallelization
            assert (
                self._cluster is not None
            ), "The cluster must not be none when using auto mapping."
            dist_programs = {}
            world_process_group = get_world_process_group()
            dist_context = None
            # auto search
            if self._dist_strategy.auto_search:
                logging.info("Start searching dist attr.")
                serial_program_info = SerialProgramInfo(
                    self._main_program,
                    self._startup_program,
                    self._loss,
                    self._optimizer,
                    self._cluster,
                )
                planner = Planner(
                    serial_program_info,
                    self,
                    algorithm_config={"name": "mcmc", "max_search_times": 5},
                )
                dist_context, _ = planner.search()
                logging.info("End searching dist attr.")

            # Serialize the dist context searched by the planner
            if dist_context is not None:
                logging.info("Start serialize searched dist attr")
                cwd = pathlib.Path().resolve()
                searched_dist_context_path = os.path.join(
                    cwd, f"searched_dist_context_{time.time()}.pkl"
                )
                saved_dist_context = {}
                ops_dist_attr = {}
                tensors_dist_attr = {}
                for key, dist_op in dist_context._dist_ops_for_program.items():
                    ops_dist_attr[key] = dist_op.dist_attr
                for (
                    key,
                    dist_tensor,
                ) in dist_context._dist_tensors_for_program.items():
                    tensors_dist_attr[key] = dist_tensor.dist_attr
                saved_dist_context["ops_dist_attr"] = ops_dist_attr
                saved_dist_context["tensors_dist_attr"] = tensors_dist_attr
                saved_dist_context[
                    "process_meshes"
                ] = dist_context._process_meshes
                with open(
                    searched_dist_context_path, "wb"
                ) as dist_context_file:
                    pickle.dump(saved_dist_context, dist_context_file)
                    os.environ[
                        'PADDLE_SEARCHED_DIST_CONTEXT_PATH'
                    ] = searched_dist_context_path
                    logging.info(
                        f"End serialize searched dist attr to {searched_dist_context_path}"
                    )

            for rank in world_process_group.ranks:
                (
                    dist_optimize_ops,
                    dist_params_grads,
                    dist_startup_prog,
                    dist_main_prog,
                    g_process_group_map,
                ) = self._get_dist_program(rank, dist_context)
                dist_programs[rank] = [dist_main_prog, g_process_group_map]

            # Do the mapping between the distributed program graph and the cluster graph
            rank_mapping_dict = mapping(dist_programs, self._cluster)
            rank_mapping = list(rank_mapping_dict.values())

            # Relaunch the training by using the rank mapping file
            with open(self._rank_mapping_path, "w") as rank_mapping_file:
                json.dump(rank_mapping, rank_mapping_file)

            enable_elastic = os.getenv("PADDLE_ENABLE_ELASTIC")
            enable_elastic = (
                True
                if enable_elastic and enable_elastic.lower() == 'true'
                else False
            )
            if enable_elastic:
                print("Auto mapping finished, now do elastic re-launch")
                sys.exit(
                    paddle.distributed.fleet.elastic.manager.ELASTIC_AUTO_PARALLEL_EXIT_CODE
                )

            original_cmd_args = os.getenv("PADDLE_ORIGINAL_CMD_ARGS")
            rank_mapping_args = " ".join(
                ["--rank_mapping_path", self._rank_mapping_path]
            )
            if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
                coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
            else:
                coverage_args = []
            new_cmd_args = (
                "-m paddle.distributed.fleet.launch"
                + " "
                + rank_mapping_args
                + " "
                + original_cmd_args
            )
            new_cmd = (
                [sys.executable, "-u"]
                + coverage_args
                + shlex.split(new_cmd_args)
            )
            new_process = subprocess.Popen(new_cmd)
            new_process.wait()
            assert (
                new_process.returncode == 0
            ), "Launch failed with rank mapping"
            print("Successfully do the second launch for auto mapping!")
            sys.exit(0)
        else:
            # Parallelization after the mapping pass
            rank = paddle.distributed.get_rank()
            dist_context = None
            searched_dist_context_path = os.getenv(
                "PADDLE_SEARCHED_DIST_CONTEXT_PATH", None
            )
            if searched_dist_context_path is not None:
                with open(
                    searched_dist_context_path, "rb"
                ) as dist_context_file:
                    saved_dist_context = pickle.load(dist_context_file)
                    dist_context = DistributedContext()
                    for op in self._main_program.global_block().ops:
                        dist_attr = saved_dist_context["ops_dist_attr"][
                            op.desc.id()
                        ]
                        dist_op = DistributedOperator(op, dist_attr)
                        dist_context.add_dist_op_for_program(dist_op)

                    vars = self._main_program.global_block().vars
                    for var in vars.values():
                        dist_attr = saved_dist_context["tensors_dist_attr"][
                            var.desc.id()
                        ]
                        dist_tensor = DistributedTensor(var, dist_attr)
                        dist_context.add_dist_tensor_for_program(dist_tensor)

                    dist_context._process_meshes = saved_dist_context[
                        "process_meshes"
                    ]

            else:
                if self._dist_strategy.auto_search:
                    serial_program_info = SerialProgramInfo(
                        self._main_program,
                        self._startup_program,
                        self._loss,
                        self._optimizer,
                        cluster=self._cluster,
                    )
                    planner = Planner(
                        serial_program_info,
                        self,
                        algorithm_config={
                            "name": "mcmc",
                            "max_search_times": 5,
                        },
                    )
                    dist_context, _ = planner.search()

            # rebuild g_process_group
            if dist_context is not None:
                pg0 = get_process_group(0)
                for process_mesh in dist_context._process_meshes:
                    pg0.add_ranks(process_mesh.process_ids)
            (
                dist_optimize_ops,
                dist_params_grads,
                dist_startup_prog,
                dist_main_prog,
                _,
            ) = self._get_dist_program(rank, dist_context, relaunch_phase=True)

            # NOTE: This is a trick to fix hang in pipeline mode when dist context is searched by planner
            if self._dist_strategy.auto_search:
                is_pipeline = False
                for op in dist_main_prog.global_block().ops:
                    if op.type == "send_v2" or op.type == "recv_v2":
                        is_pipeline = True
                        break
                if is_pipeline:
                    with paddle.static.program_guard(dist_main_prog):
                        paddle.distributed.barrier(get_process_group(0))

            # Instantiate the communicators: every process group that contains
            # the current rank is instantiated; the others are skipped.
            all_process_groups = get_all_process_groups()
            for process_group in all_process_groups:
                if len(_g_process_group_map) > 0:
                    tmp = paddle.to_tensor([1], dtype="int32")
                    paddle.distributed.all_reduce(
                        tmp, sync_op=True, group=_g_process_group_map[0]
                    )
                    paddle.device.cuda.synchronize()

                if rank not in process_group.ranks:
                    continue
                process_group.instantiate()

            # Copy distributed info to the default context
            set_default_distributed_context(self._dist_context)

            # The last step: remove all distributed attributes to be compatible
            # with inference.
            self._remove_distributed_attrs(dist_main_prog)

            return (
                dist_optimize_ops,
                dist_params_grads,
                dist_startup_prog,
                dist_main_prog,
            )

    def __deepcopy__(self, memo):
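        # Heavyweight members (programs, dist context, fleet, loss) are shared
        # rather than deep-copied; all other attributes are copied.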
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if (
                k == "_main_program"
                or k == "_startup_program"
                or k == "_dist_context"
                or k == "_fleet"
                or k == "_loss"
            ):
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result