#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import json
import logging
import os
import pathlib
import pickle
import shlex
import subprocess
import sys
import time

import paddle
import paddle.fluid.core as core
from paddle.distributed.passes import PassContext, new_pass
from paddle.distributed.utils.log_utils import get_logger
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward

from .cluster import Cluster
from .completion import Completer
from .dist_context import DistributedContext, set_default_distributed_context
from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor
from .mapper import mapping
from .partitioner import Partitioner
from .planner import Planner
from .process_group import (
    ProcessGroup,
    _g_process_group_map,
    get_all_process_groups,
    get_process_group,
    get_world_process_group,
)
from .reshard import Resharder
from .utils import SerialProgramInfo, make_data_unshard, set_grad_var_shape

_logger = get_logger(logging.INFO)


class AutoParallelizer:
    """
    AutoParallelizer is the main controller class for the auto-parallel process,
    which is triggered through the wrapped parallelize function.
    To facilitate auto parallelization, it holds information about the program, the
    cluster and the related context. In this basic version, the program information
    is retrieved from the Fleet object, the cluster information from the newly
    created Cluster object, and the context information from the newly created
    DistributedContext.
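
    A minimal usage sketch (illustrative only; it assumes a fleet object that
    already carries a user-defined optimizer and distributed strategy, and that
    loss and startup_program come from a serial static-graph model):

        parallelizer = AutoParallelizer(fleet)
        (optimize_ops, params_grads, dist_startup_prog,
         dist_main_prog) = parallelizer.parallelize(loss, startup_program)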
    """

    def __init__(self, fleet):
        self._fleet = fleet
        self._optimizer = self._fleet.user_defined_optimizer
        self._dist_strategy = self._fleet._user_defined_strategy
        self._dist_context = DistributedContext()
        self._cluster = None
        self._cluster_topo_path = os.getenv("PADDLE_CLUSTER_TOPO_PATH", None)
        if self._cluster_topo_path is not None:
            self._cluster = Cluster()
            self._cluster.build_from_file(self._cluster_topo_path)
        # Prepare information for auto mapping
        self._rank_mapping_path = os.getenv("PADDLE_RANK_MAPPING_PATH", None)
        self._enable_auto_mapping = (
            os.getenv("PADDLE_ENABLE_AUTO_MAPPING") is not None
        )
        self._pass_context = PassContext()

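        # PADDLE_NEED_RANK_MAPPING marks the rank-mapping phase of an
        # auto-mapping launch; see parallelize() below.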
        need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
        self._need_rank_mapping = (
            need_rank_mapping is not None
            and need_rank_mapping.lower() == 'true'
        )

    def _remove_distributed_attrs(self, main_program):
        suffix = core.kAutoParallelSuffix()
        # distributed attributes for variable have been removed
        # in previous process.
        for block in main_program.blocks:
            for op in block.ops:
                for attr_name in op.attr_names:
                    if suffix in attr_name:
                        op._remove_attr(attr_name)

    def _apply_pre_optimization_passes(
        self, main_program, startup_program, loss, params_grads, no_grad_set
    ):
        # apply amp pass
        if self._dist_strategy.amp:
            config = copy.deepcopy(self._dist_strategy.amp_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["loss"] = loss
            if config["use_pure_fp16"]:
                config["base_opt"] = self._optimizer
                auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                auto_parallel_fp16_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_fp16_pass.get_loss()
            else:
                auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
                auto_parallel_amp_pass.apply(
                    [main_program], [startup_program], self._pass_context
                )
                loss = auto_parallel_amp_pass.get_loss()

        # apply recompute pass
        if self._dist_strategy.recompute:
            config = copy.deepcopy(self._dist_strategy.recompute_configs)
            config["dist_context"] = self._dist_context
            config["no_grad_set"] = copy.deepcopy(no_grad_set)
            config["loss"] = loss
            auto_parallel_recompute_pass = new_pass(
                "auto_parallel_recompute", config
            )
            auto_parallel_recompute_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

    def _generate_backward(
        self,
        main_program,
        startup_program,
        loss,
        parameter_list,
        no_grad_set,
        callbacks,
    ):

        with program_guard(main_program, startup_program):
            params_grads = append_backward(
                loss,
                parameter_list,
                no_grad_set,
                callbacks,
                distop_context=self._dist_context.dist_op_context,
            )
        self._completer = Completer(self._dist_context)
        self._completer.complete_backward_annotation(main_program)
        self._dist_context.block_state.parse_backward_blocks(main_program)
        return params_grads

    def _apply_optimize(self, main_program, startup_program, params_grads):

        optimizer = copy.deepcopy(self._optimizer)
        with program_guard(main_program, startup_program):
            optimize_ops = optimizer.apply_gradients(params_grads)

        self._dist_context._serial_optimizer = optimizer
        # update completion
        self._completer = Completer(self._dist_context)
        self._completer.complete_update_annotation(main_program)

        return optimize_ops

    def _apply_post_optimization_passes(
        self, main_program, startup_program, rank, params_grads
    ):

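        # apply sharding pass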
        if self._dist_strategy.sharding:
            config = copy.deepcopy(self._dist_strategy.sharding_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            config["global_rank"] = rank
            auto_parallel_sharding_pass = new_pass(
                "auto_parallel_sharding", config
            )
            auto_parallel_sharding_pass.apply(
                [main_program], [startup_program], self._pass_context
            )
            params_grads = self._pass_context.get_attr("params_grads")

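        # apply gradient clip pass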
        config = copy.deepcopy(self._dist_strategy.sharding_configs)
        config["dist_context"] = self._dist_context
        config["params_grads"] = params_grads
        config["rank_id"] = rank
        auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
        auto_parallel_clip_pass.apply(
            [main_program], [startup_program], self._pass_context
        )

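        # apply gradient merge pass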
        if self._dist_strategy.gradient_merge:
            config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
            config["dist_context"] = self._dist_context
            config["params_grads"] = params_grads
            auto_parallel_gradient_merge_pass = new_pass(
                "auto_parallel_gradient_merge_pass", config
            )
            auto_parallel_gradient_merge_pass.apply(
                [main_program], [startup_program], self._pass_context
            )

    def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False):
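        # Build the distributed program for the given rank:
        # complete dist attrs -> backward -> pre-optimization passes (amp/recompute) ->
        # logical partition -> optimizer -> reshard -> post-optimization passes
        # (sharding, gradient clip, gradient merge).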
        completed_main_program = None
        serial_main_program = self._main_program.clone()
        serial_startup_program = self._startup_program.clone()
        serial_loss = serial_main_program.global_block().var(self._loss.name)

        # Complete dist attrs on the serial program, or reuse the given dist context
        if dist_context is None:
            # Annotation completion
            self._dist_context = DistributedContext()
            _logger.info("Start annotation dist attr.")
            self._completer = Completer(self._dist_context)
            completed_main_program = (
                self._completer.complete_forward_annotation(serial_main_program)
            )
        else:
            completed_main_program = serial_main_program
            self._dist_context = copy.deepcopy(dist_context)

        # parse forward sub block
        self._dist_context.block_state.parse_forward_blocks(serial_main_program)

        # serial backward pass
        params_grads = self._generate_backward(
            completed_main_program,
            serial_startup_program,
            serial_loss,
            self._parameter_list,
            self._no_grad_set,
            self._callbacks,
        )

        # apply pre-optimization passes (amp, recompute) on the serial program
        self._apply_pre_optimization_passes(
            completed_main_program,
            serial_startup_program,
            serial_loss,
            params_grads,
            self._no_grad_set,
        )
        # Logical partition
        partitioner = Partitioner(self._dist_context, rank)
        (
            dist_main_prog,
            dist_startup_prog,
            dist_params_grads,
        ) = partitioner.partition(
            completed_main_program, serial_startup_program, params_grads
        )

        # TODO refactor the placement of optimizer
        # generate optimize program
        dist_optimize_ops = self._apply_optimize(
            dist_main_prog, dist_startup_prog, dist_params_grads
        )

        set_grad_var_shape(dist_main_prog, self._dist_context)

        make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context)

        resharder = Resharder(
            dist_main_prog,
            dist_startup_prog,
            rank,
            self._dist_context,
            dist_params_grads,
        )
        resharder.reshard()

        self._apply_post_optimization_passes(
            dist_main_prog, dist_startup_prog, rank, dist_params_grads
        )
        g_process_group_map = None
        if not relaunch_phase:
            g_process_group_map = copy.deepcopy(_g_process_group_map)
            _g_process_group_map.clear()
            _g_process_group_map[0] = ProcessGroup(0, [])
            for process_mesh in self._dist_context._process_meshes:
                _g_process_group_map[0].add_ranks(process_mesh.process_ids)
        return (
            dist_optimize_ops,
            dist_params_grads,
            dist_startup_prog,
            dist_main_prog,
            g_process_group_map,
        )

    def parallelize(
        self,
        loss,
        startup_program,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    ):
        assert startup_program is not None
        self._loss = loss
        self._startup_program = startup_program
        self._main_program = loss.block.program
        self._parameter_list = parameter_list
        self._no_grad_set = no_grad_set
        self._callbacks = callbacks

        if self._enable_auto_mapping and self._need_rank_mapping:
            # Do the mapping pass before parallelization
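            # Build a distributed program for each rank in the world process group,
            # map those program graphs onto the cluster graph, write the resulting
            # rank mapping to a file, and relaunch the job with that mapping.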
            assert (
                self._cluster is not None
            ), "The cluster must not be none when using auto mapping."
            dist_programs = {}
            world_process_group = get_world_process_group()
            dist_context = None
            # auto search
            if self._dist_strategy.auto_search:
                logging.info("Start searching dist attr.")
                serial_program_info = SerialProgramInfo(
                    self._main_program,
                    self._startup_program,
                    self._loss,
                    self._optimizer,
                    self._cluster,
                )
                planner = Planner(
                    serial_program_info,
                    self,
                    algorithm_config={"name": "mcmc", "max_search_times": 5},
                )
                dist_context, _ = planner.search()
                logging.info("End searching dist attr.")

            # serialize the dist context by planner
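            # The searched per-op and per-tensor dist attrs plus the process meshes
            # are pickled, and the file path is exported through
            # PADDLE_SEARCHED_DIST_CONTEXT_PATH so the relaunched processes can
            # restore the same searched result.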
            if dist_context is not None:
                logging.info("Start serialize searched dist attr")
                cwd = pathlib.Path().resolve()
                searched_dist_context_path = os.path.join(
                    cwd, f"searched_dist_context_{time.time()}.pkl"
                )
                saved_dist_context = {}
                ops_dist_attr = {}
                tensors_dist_attr = {}
                for key, dist_op in dist_context._dist_ops_for_program.items():
                    ops_dist_attr[key] = dist_op.dist_attr
                for (
                    key,
                    dist_tensor,
                ) in dist_context._dist_tensors_for_program.items():
                    tensors_dist_attr[key] = dist_tensor.dist_attr
                saved_dist_context["ops_dist_attr"] = ops_dist_attr
                saved_dist_context["tensors_dist_attr"] = tensors_dist_attr
                saved_dist_context[
                    "process_meshes"
                ] = dist_context._process_meshes
                with open(
                    searched_dist_context_path, "wb"
                ) as dist_context_file:
                    pickle.dump(saved_dist_context, dist_context_file)
                    os.environ[
                        'PADDLE_SEARCHED_DIST_CONTEXT_PATH'
                    ] = searched_dist_context_path
                    logging.info(
                        f"End serialize searched dist attr to {searched_dist_context_path}"
                    )

            for rank in world_process_group.ranks:
                (
                    dist_optimize_ops,
                    dist_params_grads,
                    dist_startup_prog,
                    dist_main_prog,
                    g_process_group_map,
                ) = self._get_dist_program(rank, dist_context)
                dist_programs[rank] = [dist_main_prog, g_process_group_map]

            # Do the mapping between the distributed program graph and the cluster graph
            rank_mapping_dict = mapping(dist_programs, self._cluster)
            rank_mapping = list(rank_mapping_dict.values())

            # Relaunch the training by using the rank mapping file
            with open(self._rank_mapping_path, "w") as rank_mapping_file:
                json.dump(rank_mapping, rank_mapping_file)

            enable_elastic = os.getenv("PADDLE_ENABLE_ELASTIC")
            enable_elastic = (
                enable_elastic is not None
                and enable_elastic.lower() == 'true'
            )
            if enable_elastic:
                print("Auto mapping finished, now do elastic re-launch")
                sys.exit(
                    paddle.distributed.fleet.elastic.manager.ELASTIC_AUTO_PARALLEL_EXIT_CODE
                )

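            # Relaunch the original command through paddle.distributed.fleet.launch,
            # passing the rank mapping file produced above.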
            original_cmd_args = os.getenv("PADDLE_ORIGINAL_CMD_ARGS")
            rank_mapping_args = " ".join(
                ["--rank_mapping_path", self._rank_mapping_path]
            )
            if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
                coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
            else:
                coverage_args = []
            new_cmd_args = (
                "-m paddle.distributed.fleet.launch"
                + " "
                + rank_mapping_args
                + " "
                + original_cmd_args
            )
            new_cmd = (
                [sys.executable, "-u"]
                + coverage_args
                + shlex.split(new_cmd_args)
            )
            new_process = subprocess.Popen(new_cmd)
            new_process.wait()
            assert (
                new_process.returncode == 0
            ), "Launch failed with rank mapping"
            print("Successfully do the second launch for auto mapping!")
            sys.exit(0)
        else:
            # Parallelization after the mapping pass
            rank = paddle.distributed.get_rank()
            dist_context = None
            searched_dist_context_path = os.getenv(
                "PADDLE_SEARCHED_DIST_CONTEXT_PATH", None
            )
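            # Restore the dist context searched and serialized in the mapping
            # phase, if one was produced.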
            if searched_dist_context_path is not None:
                with open(
                    searched_dist_context_path, "rb"
                ) as dist_context_file:
                    saved_dist_context = pickle.load(dist_context_file)
                    dist_context = DistributedContext()
                    for op in self._main_program.global_block().ops:
                        dist_attr = saved_dist_context["ops_dist_attr"][
                            op.desc.id()
                        ]
                        dist_op = DistributedOperator(op, dist_attr)
                        dist_context.add_dist_op_for_program(dist_op)

                    vars = self._main_program.global_block().vars
                    for var in vars.values():
                        dist_attr = saved_dist_context["tensors_dist_attr"][
                            var.desc.id()
                        ]
                        dist_tensor = DistributedTensor(var, dist_attr)
                        dist_context.add_dist_tensor_for_program(dist_tensor)

                    dist_context._process_meshes = saved_dist_context[
                        "process_meshes"
                    ]

            else:
                if self._dist_strategy.auto_search:
                    serial_program_info = SerialProgramInfo(
                        self._main_program,
                        self._startup_program,
                        self._loss,
                        self._optimizer,
                        cluster=self._cluster,
                    )
                    planner = Planner(
                        serial_program_info,
                        self,
                        algorithm_config={
                            "name": "mcmc",
                            "max_search_times": 5,
                        },
                    )
                    dist_context, _ = planner.search()

            # rebuild g_process_group
            if dist_context is not None:
                pg0 = get_process_group(0)
                for process_mesh in dist_context._process_meshes:
                    pg0.add_ranks(process_mesh.process_ids)
            (
                dist_optimize_ops,
                dist_params_grads,
                dist_startup_prog,
                dist_main_prog,
                _,
            ) = self._get_dist_program(rank, dist_context, relaunch_phase=True)

            # NOTE: This is a trick to fix hang in pipeline mode when dist context is searched by planner
            if self._dist_strategy.auto_search:
                is_pipeline = False
                for op in dist_main_prog.global_block().ops:
                    if op.type == "send_v2" or op.type == "recv_v2":
                        is_pipeline = True
                        break
                if is_pipeline:
                    with paddle.static.program_guard(dist_main_prog):
                        paddle.distributed.barrier()

            # Traverse different rank programs and traverse each op of them,
            # instantiate communication by process_mapping.
            all_process_groups = get_all_process_groups()
            for process_group in all_process_groups:
                if rank not in process_group.ranks:
                    continue
                process_group.instantiate()

            # Copy distributed info to the default context
            set_default_distributed_context(self._dist_context)

            # The last step: remove all distributed attributes to be compatible
            # with inference.
            self._remove_distributed_attrs(dist_main_prog)

            return (
                dist_optimize_ops,
                dist_params_grads,
                dist_startup_prog,
                dist_main_prog,
            )

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
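        # Programs, contexts, the fleet object and the loss are shared by
        # reference; everything else is deep-copied.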
        for k, v in self.__dict__.items():
            if (
                k == "_main_program"
                or k == "_startup_program"
                or k == "_dist_context"
                or k == "_fleet"
                or k == "_loss"
            ):
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result