# dist_context.py
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy
from collections import defaultdict
17

18
from paddle.distributed.passes import PassContext
19 20 21
from paddle.fluid import core, framework
from paddle.fluid.framework import set_flags

22
from .dist_op import DistributedOperator
23
from .dist_tensor import DistributedTensor
24
from .process_mesh import ProcessMesh
25 26 27 28 29
from .utils import (
    __no_shape_var_type__,
    _copy_dist_attr_to_cpp,
    is_loss_grad_op,
)
30

# A default DistributedContext is always available to callers: it is created
# lazily on first access and can be replaced via
# set_default_distributed_context().
_g_default_distributed_context = None


def get_default_distributed_context():
    """Return the default DistributedContext, creating one on first use."""
    global _g_default_distributed_context
    if _g_default_distributed_context is not None:
        return _g_default_distributed_context
    set_default_distributed_context(DistributedContext())
    return _g_default_distributed_context


def set_default_distributed_context(dist_context):
    """Install ``dist_context`` as the default DistributedContext."""
    global _g_default_distributed_context
    _g_default_distributed_context = dist_context


48 49 50 51
def _node_id(node):
    return (node.node.graph_id(), node.node.id())


52 53 54 55 56 57
class DistributedContext:
    """
    DistributedContext is used to collect related distributed information for program and graph.
    One auto-parallel run should use its own DistributedContext to avoid interfering other run.
    """

58 59 60 61 62 63 64 65 66 67 68
    def __init__(
        self,
        serial_main_prog=None,
        serial_startup_prog=None,
        serial_optimizer=None,
        serial_loss=None,
        feed_vars=None,
        fetch_vars=None,
        cluster=None,
        strategy=None,
    ):
        """Create a context around the original (serial) programs.

        Args:
            serial_main_prog: original serial main program; ``initialize``
                works on a clone of it.
            serial_startup_prog: original serial startup program; cloned
                like the main program.
            serial_optimizer: optimizer associated with the serial program.
            serial_loss: loss variable, or a list holding at most one.
            feed_vars: mapping from key to a list of feed variables.
                Defaults to a new empty dict.
            fetch_vars: mapping from key to a list of fetch variables (the
                ``"metrics"`` entry holds a list of lists). Defaults to a new
                empty dict.
            cluster: cluster description, stored as-is.
            strategy: strategy object, stored as-is.
        """
        # A literal ``{}`` default would be a single dict object shared by
        # every context created without explicit feed/fetch vars.
        if feed_vars is None:
            feed_vars = {}
        if fetch_vars is None:
            fetch_vars = {}

        # Data members related to original programs (unchanged)
        self._original_serial_main_program = serial_main_prog
        self._original_serial_startup_program = serial_startup_prog
        self._original_serial_optimizer = serial_optimizer
        self._original_serial_loss = serial_loss
        self._original_serial_feed_vars = feed_vars
        self._original_serial_fetch_vars = fetch_vars

        # Data members related to programs (changed)
        self._serial_main_program = None
        self._serial_startup_program = None
        self._serial_loss = None
        self._serial_optimizer = None
        self._serial_feed_vars = {}
        self._serial_fetch_vars = {}
        self._lr_optimizer = None  # record the optimizer holding lr_scheduler

        # Data members related to the program
        self._dist_tensors_for_program = {}
        self._dist_ops_for_program = {}

        # Data members related to the graph
        self._serial_graph = None
        self._dist_tensors_for_graph = {}
        self._dist_ops_for_graph = {}
        self._node_id_to_tensor_id = {}
        self._node_id_to_op_id = {}

        # Data members related to the distributed programs
        # Distributed programs
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._dist_op_context = DistributedOperatorContext()
        self._process_meshes = []

        self._cluster = cluster
        self._strategy = strategy

        # Pass Context
        self._pass_context = PassContext()
        self._block_state = BlockState()

        # Other data members
        self._serial_ordered_tensor_nodes = []
        self._serial_ordered_op_nodes = []
        self._serial_ordered_nodes = []

        self._is_initialized = False
        # TODO: need a better way to remove the following flag
        self._need_copy_dist_attr_to_graph = False
        # Stacks filled by _backup() and popped by _restore().
        self._backup_pass_context_stack = []
        self._backup_block_state_stack = []
        self._backup_dist_tensors_for_program_stack = []
        self._backup_dist_ops_for_program_stack = []
        self._backup_serial_main_program_stack = []
        self._backup_serial_startup_program_stack = []

        # flag whether scale gradient with dp size
        self._gradient_scale = True

        # A flag indicates whether the used parallelism is data parallel
        self._data_parallel = False

    @property
134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
    def serial_main_program(self):
        return self._serial_main_program

    @property
    def serial_startup_program(self):
        return self._serial_startup_program

    @property
    def serial_loss(self):
        return self._serial_loss

    @property
    def serial_optimizer(self):
        return self._serial_optimizer

149 150 151 152 153 154 155
    @property
    def serial_feed_vars(self):
        return self._serial_feed_vars

    @property
    def serial_fetch_vars(self):
        return self._serial_fetch_vars
156

157 158 159 160 161 162 163 164 165 166 167 168
    @property
    def dist_main_programs(self):
        return self._dist_main_programs

    @property
    def dist_startup_programs(self):
        return self._dist_startup_programs

    @property
    def cluster(self):
        return self._cluster

169 170 171 172
    @property
    def strategy(self):
        return self._strategy

173 174 175 176
    @property
    def serial_graph(self):
        return self._serial_graph

177 178 179 180
    @property
    def serial_ordered_nodes(self):
        return self._serial_ordered_nodes

181 182 183 184
    @property
    def process_meshes(self):
        return self._process_meshes

185 186 187 188
    @property
    def pass_context(self):
        return self._pass_context

189 190 191 192
    @property
    def dist_op_context(self):
        return self._dist_op_context

193 194 195 196
    @property
    def block_state(self):
        return self._block_state

197
    @property
198
    def has_annotation(self):
199
        return len(self._dist_tensors_for_program) or len(
200 201
            self._dist_ops_for_program
        )
202

203 204 205 206 207 208 209 210
    @property
    def gradient_scale(self):
        return self._gradient_scale

    @gradient_scale.setter
    def gradient_scale(self, gs):
        self._gradient_scale = gs

211 212 213 214 215 216 217 218
    @property
    def data_parallel(self):
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, dp):
        self._data_parallel = dp

219 220
    def _backup_serial_info(self, mode):
        self._backup_serial_main_program_stack.append(
221 222
            self._serial_main_program.clone()
        )
223
        self._backup_serial_startup_program_stack.append(
224 225 226 227 228
            self._serial_startup_program.clone()
        )
        self._backup_pass_context_stack.append(
            copy.deepcopy(self._pass_context)
        )
229 230 231 232
        self._backup_block_state_stack.append(copy.deepcopy(self._block_state))

    def _backup_dist_info(self, mode):
        self._backup_dist_tensors_for_program_stack.append(
233 234
            copy.deepcopy(self._dist_tensors_for_program)
        )
235
        self._backup_dist_ops_for_program_stack.append(
236 237
            copy.deepcopy(self._dist_ops_for_program)
        )
238 239 240 241 242 243 244 245

    def _backup(self, serial=True, serial_mode=None, dist=True, dist_mode=None):
        # Use this function carefully
        if serial:
            self._backup_serial_info(serial_mode)
        if dist:
            self._backup_dist_info(dist_mode)

246
    def _restore_serial_loss(self):
247 248
        if self._original_serial_loss:
            if isinstance(self._original_serial_loss, list):
249 250 251 252 253
                if len(self._original_serial_loss) == 1:
                    loss = self._original_serial_loss[0]
                    block_idx = loss.block.idx
                    var_name = loss.name
                    var = self._serial_main_program.blocks[
254 255
                        block_idx
                    ]._var_recursive(var_name)
256 257 258 259 260
                    self._serial_loss = var
                elif len(self._original_serial_loss) == 0:
                    self._serial_loss = []
                else:
                    raise ValueError("multi loss vars are not supported.")
261
            else:
262 263 264
                block_idx = self._original_serial_loss.block.idx
                var_name = self._original_serial_loss.name
                var = self._serial_main_program.blocks[
265 266
                    block_idx
                ]._var_recursive(var_name)
267 268
                self._serial_loss = var

269
    def _restore_serial_feed_vars(self):
270 271 272 273 274 275
        for key, var_list in self._original_serial_feed_vars.items():
            new_var_list = []
            for var in var_list:
                block_idx = var.block.idx
                var_name = var.name
                var = self._serial_main_program.blocks[
276 277
                    block_idx
                ]._var_recursive(var_name)
278 279 280
                new_var_list.append(var)
            self._serial_feed_vars[key] = new_var_list

281
    def _restore_serial_fetch_vars(self):
282 283
        for key, var_list in self._original_serial_fetch_vars.items():
            new_var_list = []
284 285 286 287 288 289 290 291
            # metrics is a list of list
            if key == "metrics":
                for inner_var_list in var_list:
                    new_inner_var_list = []
                    for var in inner_var_list:
                        block_idx = var.block.idx
                        var_name = var.name
                        var = self._serial_main_program.blocks[
292 293
                            block_idx
                        ]._var_recursive(var_name)
294 295 296 297 298 299 300
                        new_inner_var_list.append(var)
                    new_var_list.append(new_inner_var_list)
            else:
                for var in var_list:
                    block_idx = var.block.idx
                    var_name = var.name
                    var = self._serial_main_program.blocks[
301 302
                        block_idx
                    ]._var_recursive(var_name)
303
                    new_var_list.append(var)
304 305
            self._serial_fetch_vars[key] = new_var_list

306 307
    def _restore_serial_info(self, mode="to_backup"):
        if mode == "to_backup":
308 309
            self._serial_main_program = (
                self._backup_serial_main_program_stack.pop()
310
            )
311 312
            self._serial_startup_program = (
                self._backup_serial_startup_program_stack.pop()
313 314 315 316
            )
        elif mode == "to_original":
            assert self._original_serial_main_program is not None
            assert self._original_serial_startup_program is not None
317 318
            self._serial_main_program = (
                self._original_serial_main_program.clone()
319
            )
320 321
            self._serial_startup_program = (
                self._original_serial_startup_program.clone()
322 323 324 325 326 327
            )

        self._restore_serial_loss()
        self._restore_serial_feed_vars()
        self._restore_serial_fetch_vars()
        self._serial_optimizer = self._original_serial_optimizer
328 329 330 331 332
        self._pass_context = self._backup_pass_context_stack.pop()
        self._block_state = self._backup_block_state_stack.pop()

    def _restore_dist_info(self, mode="to_backup"):
        """Restore the program-level distributed tensors and operators.

        Args:
            mode: ``"to_backup"`` pops the most recent backup;
                ``"to_original"`` deep-copies the snapshot taken when the
                program dist attrs were initialized; ``"to_default"`` keeps
                only entries whose ids were recorded at initialization
                (``self._tensors_ids`` / ``self._ops_ids``) and resets their
                dist attrs; any other value drops every entry.

        In all modes the distributed programs, operator context and process
        meshes are cleared, and the graph is flagged to receive a fresh copy
        of the program dist attrs.
        """
        if mode == "to_backup":
            self._dist_tensors_for_program = (
                self._backup_dist_tensors_for_program_stack.pop()
            )
            self._dist_ops_for_program = (
                self._backup_dist_ops_for_program_stack.pop()
            )
        elif mode == "to_original":
            assert self._original_dist_tensors_for_program
            assert self._original_dist_ops_for_program
            self._dist_tensors_for_program = copy.deepcopy(
                self._original_dist_tensors_for_program
            )
            self._dist_ops_for_program = copy.deepcopy(
                self._original_dist_ops_for_program
            )
        elif mode == "to_default":

            def _reset_or_drop(dist_objs, initial_ids):
                # Reset entries that existed at initialization; drop the
                # ones added afterwards (collected first so the dict is not
                # mutated while being iterated).
                added_later = []
                for obj_id, dist_obj in dist_objs.items():
                    if obj_id in initial_ids:
                        dist_obj.dist_attr.reset()
                    else:
                        added_later.append(obj_id)
                for obj_id in added_later:
                    dist_objs.pop(obj_id)

            # Membership is tested once per entry, so test against sets
            # instead of the recorded id lists (linear instead of quadratic).
            # The id lists are only read when there is something to check.
            if self._dist_tensors_for_program:
                _reset_or_drop(
                    self._dist_tensors_for_program, set(self._tensors_ids)
                )
            if self._dist_ops_for_program:
                _reset_or_drop(self._dist_ops_for_program, set(self._ops_ids))
        else:
            self._dist_tensors_for_program.clear()
            self._dist_ops_for_program.clear()
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._dist_op_context = DistributedOperatorContext()
        self._need_copy_dist_attr_to_graph = True
        self._process_meshes = []

    def _restore(
        self,
        serial=True,
        serial_mode="to_backup",
        dist=True,
        dist_mode="to_backup",
    ):
395 396 397 398 399 400
        # Use this function carefully
        if serial:
            self._restore_serial_info(serial_mode)
        if dist:
            self._restore_dist_info(dist_mode)

401
    def initialize(self, with_graph=True, with_cpp=False):
        """Prepare the serial programs and their distributed attributes.

        On the first call this clones the original programs (unless working
        programs are already set), re-resolves the loss and feed/fetch vars,
        builds the program-level dist attrs, records the initial tensor/op
        ids and, optionally, copies the attrs to C++ (``with_cpp``) and
        builds the IR graph with its dist attrs (``with_graph``). On any call
        with ``with_graph``, the program attrs are re-copied to the graph
        when a restore has flagged that as needed.
        """
        if not self._is_initialized:
            if (
                not self._serial_main_program
                and self._original_serial_main_program
            ):
                self._serial_main_program = (
                    self._original_serial_main_program.clone()
                )
            if (
                not self._serial_startup_program
                and self._original_serial_startup_program
            ):
                self._serial_startup_program = (
                    self._original_serial_startup_program.clone()
                )
            if not self._serial_loss:
                self._restore_serial_loss()
            if not self._serial_optimizer:
                self._serial_optimizer = self._original_serial_optimizer
            if not self._serial_feed_vars:
                self._restore_serial_feed_vars()
            if not self._serial_fetch_vars:
                self._restore_serial_fetch_vars()

            # _init_dist_attr_for_program() ends by deep-copying the fresh
            # tensors/ops into self._original_dist_*_for_program, the
            # snapshot used by later "to_original" restores; a second
            # deep copy here would be identical and is skipped.
            self._init_dist_attr_for_program()
            # Ids present right after initialization; "to_default" restores
            # keep (and reset) only these entries.
            self._tensors_ids = list(self._dist_tensors_for_program.keys())
            self._ops_ids = list(self._dist_ops_for_program.keys())
            self._is_initialized = True

            # TODO: This will be removed in the future
            if with_cpp:
                _copy_dist_attr_to_cpp(self)

            if with_graph:
                set_flags({"FLAGS_convert_all_blocks": True})
                self._serial_graph = framework.IrGraph(
                    core.Graph(self._serial_main_program.desc)
                )
                self._init_dist_attr_for_graph()
                self._need_copy_dist_attr_to_graph = False

        if self._need_copy_dist_attr_to_graph and with_graph:
            self.copy_dist_attr_from_program_to_graph()

    def add_process_mesh(self, process_mesh):
        """Register ``process_mesh`` unless an equal mesh is already known.

        Raises:
            AssertionError: if ``process_mesh`` is not a ProcessMesh.
        """
        # The message previously referred to "dim_mapping", which is not the
        # argument being checked here.
        assert isinstance(
            process_mesh, ProcessMesh
        ), 'The type of process_mesh must be ProcessMesh.'
        if process_mesh not in self._process_meshes:
            self._process_meshes.append(process_mesh)

    def add_dist_tensor_for_program(self, dist_tensor):
        inner_serial_tensor = dist_tensor.serial_tensor
458
        inner_serial_tensor_id = inner_serial_tensor.desc.original_id()
459 460 461 462
        self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor

    def add_dist_op_for_program(self, dist_op):
        inner_serial_op = dist_op.serial_op
463
        inner_serial_op_id = inner_serial_op.desc.original_id()
464 465 466 467
        self._dist_ops_for_program[inner_serial_op_id] = dist_op

    def get_dist_tensor_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
468 469 470 471 472
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor
        else:
            serial_tensor_id = serial_tensor.desc.original_id()
473
            dist_tensor = self._dist_tensors_for_program.get(
474 475
                serial_tensor_id, None
            )
476 477 478 479
            if dist_tensor:
                return dist_tensor
            else:
                return None
480 481

    def get_dist_tensor_for_graph(self, serial_tensor_node):
        """Return the graph-level dist tensor for a var node, or None."""
        key = _node_id(serial_tensor_node)
        return self._dist_tensors_for_graph.get(key)

    def get_dist_op_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op
        else:
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op
            else:
                return None
497

498 499 500 501 502
    def del_dist_op_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        if self._dist_ops_for_program.get(serial_tensor_id, None):
            del self._dist_ops_for_program[serial_tensor_id]

503
    def get_dist_op_for_graph(self, serial_op_node):
        """Return the graph-level dist op for an op node, or None."""
        key = _node_id(serial_op_node)
        return self._dist_ops_for_graph.get(key)

    def get_tensor_dist_attr_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
513
            serial_tensor_id = serial_tensor.desc.original_id()
514
            dist_tensor = self._dist_tensors_for_program.get(
515 516
                serial_tensor_id, None
            )
517 518 519 520
            if dist_tensor:
                return dist_tensor.dist_attr
            else:
                return None
521

522 523 524 525 526 527 528
    def get_tensor_dist_attr_for_program_with_id(self, tensor_id):
        dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            return None

529 530 531 532 533
    def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
        """Wrap ``serial_tensor`` with ``dist_attr`` and register it."""
        self.add_dist_tensor_for_program(
            DistributedTensor(serial_tensor, dist_attr)
        )

    def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
        """Return the dist attr of the graph dist tensor for a var node."""
        key = _node_id(serial_tensor_node)
        dist_tensor = self._dist_tensors_for_graph.get(key, None)
        return dist_tensor.dist_attr if dist_tensor else None

    def get_op_dist_attr_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
549 550 551 552 553 554
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op.dist_attr
            else:
                return None
555

556 557 558 559 560 561 562
    def get_op_dist_attr_for_program_with_id(self, op_id):
        dist_op = self._dist_ops_for_program.get(op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            return None

563 564 565 566 567
    def set_op_dist_attr_for_program(self, serial_op, dist_attr):
        """Wrap ``serial_op`` with ``dist_attr`` and register it."""
        self.add_dist_op_for_program(DistributedOperator(serial_op, dist_attr))

    def get_op_dist_attr_for_graph(self, serial_op_node):
        """Return the dist attr of the graph dist op for an op node."""
        key = _node_id(serial_op_node)
        dist_op = self._dist_ops_for_graph.get(key, None)
        return dist_op.dist_attr if dist_op else None

    def get_dist_attr_for_graph(self, serial_node):
        """Return the dist attr attached to a graph node, or None.

        Var nodes whose ``var()`` is not None are looked up among the graph
        dist tensors; otherwise op nodes whose ``op()`` is not None are
        looked up among the graph dist ops. Any other node yields None.
        """
        if serial_node.is_var() and serial_node.var() is not None:
            registry = self._dist_tensors_for_graph
        elif serial_node.is_op() and serial_node.op() is not None:
            registry = self._dist_ops_for_graph
        else:
            return None
        dist_obj = registry.get(_node_id(serial_node), None)
        return dist_obj.dist_attr if dist_obj else None

    def _init_dist_attr_for_program(self, no_default=False):
        """Create a dist tensor/op for every var/op of the serial main program.

        Unless ``no_default`` is set, the process meshes and the annotations
        made in the default distributed context are deep-copied into this
        context first (annotations only when that context is not ``self``).
        The data-parallel flag is taken from the default context (or from
        ``self`` when ``no_default``). Every var/op still lacking an entry
        afterwards gets a fresh one. Finally both mappings are deep-copied
        into ``self._original_dist_*_for_program`` for later restores.
        """
        if no_default:
            default_ctx = self
        else:
            default_ctx = get_default_distributed_context()
            self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
        # Annotations are imported only from a context other than this one.
        import_annotations = default_ctx is not self
        # Copy the data parallel flag from the default context
        self._data_parallel = default_ctx.data_parallel

        for block in self._serial_main_program.blocks:
            for tensor in block.vars.values():
                if import_annotations:
                    annotated = default_ctx.get_dist_tensor_for_program(tensor)
                    if annotated:
                        imported = DistributedTensor(tensor)
                        imported.dist_attr = copy.deepcopy(annotated.dist_attr)
                        self.add_dist_tensor_for_program(imported)
                if self.get_dist_tensor_for_program(tensor) is None:
                    self.add_dist_tensor_for_program(DistributedTensor(tensor))
            for op in block.ops:
                if import_annotations:
                    annotated = default_ctx.get_dist_op_for_program(op)
                    if annotated:
                        imported = DistributedOperator(op)
                        imported.dist_attr = copy.deepcopy(annotated.dist_attr)
                        self.add_dist_op_for_program(imported)
                if self.get_dist_op_for_program(op) is None:
                    self.add_dist_op_for_program(DistributedOperator(op))

        self._original_dist_tensors_for_program = copy.deepcopy(
            self._dist_tensors_for_program
        )
        self._original_dist_ops_for_program = copy.deepcopy(
            self._dist_ops_for_program
        )

    def _order_nodes_by_program_order(self):
        """Arrange the serial graph's var/op nodes in program order.

        Op nodes are visited in ascending ``original_desc_id`` order. Before
        each op, its input var nodes that are not yet placed are emitted
        (sorted by ``original_desc_id``); after it, its not-yet-placed output
        var nodes are emitted the same way. Results are stored in
        ``self._serial_ordered_nodes``, ``self._serial_ordered_tensor_nodes``
        and ``self._serial_ordered_op_nodes`` (the latter two sorted by
        ``original_desc_id``). Var nodes never reached this way are stored in
        ``self._serial_orphan_tensor_nodes``, and a warning is printed when
        the ordered node count differs from the total number of var/op nodes.
        """

        def _desc_order(node):
            return node.node.original_desc_id()

        # Every var/op node carrying a desc, across all sub-graphs.
        all_tensor_nodes = []
        all_op_nodes = []
        for graph in self._serial_graph.all_sub_graphs():
            for node in graph.all_nodes():
                if node.is_var() and node.var() is not None:
                    all_tensor_nodes.append(node)
                if node.is_op() and node.op() is not None:
                    all_op_nodes.append(node)
        all_tensor_nodes.sort(key=_desc_order)
        all_op_nodes.sort(key=_desc_order)
        num_nodes_before = len(all_tensor_nodes) + len(all_op_nodes)

        ordered_tensor_nodes = []
        ordered_op_nodes = []
        ordered_nodes = []
        # _node_id keys of everything already in ``ordered_nodes``. A set
        # lookup replaces a linear scan of ``ordered_nodes`` per candidate,
        # which made this method quadratic in the graph size. As with that
        # scan, the set grows only after a whole input/output group has been
        # collected, so a node listed twice within one group is kept twice.
        placed_ids = set()

        def _emit_unplaced_tensors(candidates):
            group = [
                node
                for node in candidates
                if node.is_var()
                and node.var() is not None
                and _node_id(node) not in placed_ids
            ]
            # The tensor list keeps discovery order (it is sorted at the
            # end); the combined node list receives the group sorted.
            ordered_tensor_nodes.extend(group)
            group.sort(key=_desc_order)
            ordered_nodes.extend(group)
            placed_ids.update(_node_id(node) for node in group)

        for op_node in all_op_nodes:
            _emit_unplaced_tensors(op_node.inputs)
            ordered_nodes.append(op_node)
            ordered_op_nodes.append(op_node)
            placed_ids.add(_node_id(op_node))
            _emit_unplaced_tensors(op_node.outputs)

        ordered_tensor_nodes.sort(key=_desc_order)
        ordered_op_nodes.sort(key=_desc_order)
        self._serial_ordered_tensor_nodes = ordered_tensor_nodes
        self._serial_ordered_op_nodes = ordered_op_nodes
        self._serial_ordered_nodes = ordered_nodes
        assert len(self._serial_ordered_nodes) == len(
            self._serial_ordered_tensor_nodes
        ) + len(self._serial_ordered_op_nodes)

        ordered_tensor_ids = {_node_id(node) for node in ordered_tensor_nodes}
        self._serial_orphan_tensor_nodes = [
            node
            for node in all_tensor_nodes
            if _node_id(node) not in ordered_tensor_ids
        ]
        if len(self._serial_ordered_nodes) != num_nodes_before:
            print(
                "WARNING: there are some orphan tensors or ops which are not used in the execution."
            )

    def _init_dist_attr_for_graph(self):
        """Attach graph-level dist tensors/ops to the ordered graph nodes.

        After ordering the nodes, each var/op node is matched against the
        program-level mapping. An entry matches when its key, or its serial
        object's original desc id, equals the node's ``original_desc_id``.
        When several entries match, the last one in iteration order wins.
        The matched key is stored in ``self._node_id_to_tensor_id`` /
        ``self._node_id_to_op_id``. A new DistributedTensor /
        DistributedOperator sharing the matched serial object and dist attr
        is stored for the node.
        """

        def _last_match(desc_id, registry, serial_of):
            # (matched, key, dist_obj) of the last matching entry.
            matched, found_key, found = False, None, None
            for key, dist_obj in registry.items():
                if (
                    desc_id == key
                    or desc_id == serial_of(dist_obj).desc.original_id()
                ):
                    matched, found_key, found = True, key, dist_obj
            return matched, found_key, found

        self._order_nodes_by_program_order()
        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                matched, key, dist_tensor = _last_match(
                    node.node.original_desc_id(),
                    self._dist_tensors_for_program,
                    lambda dist: dist.serial_tensor,
                )
                if matched:
                    self._node_id_to_tensor_id[_node_id(node)] = key
                assert (
                    dist_tensor is not None
                ), "Tensor must have a distributed tensor after the initialization for program."
                self._dist_tensors_for_graph[
                    _node_id(node)
                ] = DistributedTensor(
                    dist_tensor.serial_tensor, dist_tensor.dist_attr
                )
            if node.is_op() and node.op() is not None:
                matched, key, dist_op = _last_match(
                    node.node.original_desc_id(),
                    self._dist_ops_for_program,
                    lambda dist: dist.serial_op,
                )
                if matched:
                    self._node_id_to_op_id[_node_id(node)] = key
                assert (
                    dist_op is not None
                ), "Operator must have a distributed operator after the initialization for program."
                self._dist_ops_for_graph[_node_id(node)] = DistributedOperator(
                    dist_op.serial_op, dist_op.dist_attr
                )

    def clear_dist_info_for_program(self):
        """Empty the program-side distributed tensor and operator registries.

        The registries are cleared in place (tensors first, then operators),
        so any outside reference to these dicts observes the change.
        """
        for registry in (
            self._dist_tensors_for_program,
            self._dist_ops_for_program,
        ):
            registry.clear()

    def clear_dist_info_for_graph(self):
        """Empty the graph-side distributed tensor and operator registries.

        The registries are cleared in place (tensors first, then operators),
        so any outside reference to these dicts observes the change.
        """
        for registry in (
            self._dist_tensors_for_graph,
            self._dist_ops_for_graph,
        ):
            registry.clear()

774 775 776 777 778
    def copy_dist_attr_from_program_to_graph(self):
        """
        Rebuild the graph-side distributed tensors and operators from the
        program-side ones.

        Each var node that carries a var, and each op node that carries an op,
        in ``self.serial_ordered_nodes`` is matched against the program-side
        registry using ``node.node.original_desc_id()``. An entry matches when
        its registry key equals that id, or when its serial tensor/op desc
        reports that id as ``original_id()``. A fresh ``DistributedTensor`` /
        ``DistributedOperator`` sharing the matched entry's serial object and
        dist attr is then stored in the graph-side registry under
        ``_node_id(node)``.

        Raises:
            AssertionError: if a var/op node has no program-side counterpart.
        """

        def _last_match(desc_id, registry, serial_of):
            # The whole registry is scanned without stopping early; when
            # several entries match, the last one in iteration order wins.
            matched = None
            for registry_id, entry in registry.items():
                if (
                    desc_id == registry_id
                    or desc_id == serial_of(entry).desc.original_id()
                ):
                    matched = entry
            return matched

        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                source_tensor = _last_match(
                    node.node.original_desc_id(),
                    self._dist_tensors_for_program,
                    lambda entry: entry.serial_tensor,
                )
                assert (
                    source_tensor is not None
                ), "Tensor must have a distributed tensor after the initialization for program."
                graph_key = _node_id(node)
                self._dist_tensors_for_graph[graph_key] = DistributedTensor(
                    source_tensor.serial_tensor, source_tensor.dist_attr
                )

            if node.is_op() and node.op() is not None:
                source_op = _last_match(
                    node.node.original_desc_id(),
                    self._dist_ops_for_program,
                    lambda entry: entry.serial_op,
                )
                assert (
                    source_op is not None
                ), "Operator must have a distributed operator after the initialization for program."
                graph_key = _node_id(node)
                self._dist_ops_for_graph[graph_key] = DistributedOperator(
                    source_op.serial_op, source_op.dist_attr
                )

820
    def copy_dist_attr_from_graph_to_program(self):
        """
        Write the graph-side distributed attributes back to the program-side
        distributed tensors and operators.

        Graph nodes are resolved to program-side ids through
        ``self._node_id_to_tensor_id`` / ``self._node_id_to_op_id``. When
        several var nodes map to the same program tensor, only the first one
        in ``self._serial_ordered_nodes`` order is used. Afterwards every
        orphan tensor node gets ``self._process_meshes[0]`` as process mesh.

        Raises:
            AssertionError: if the context is not initialized, or if an orphan
                tensor node has no program-side distributed tensor under either
                its id or its original id.
        """
        assert (
            self._is_initialized
        ), "Both program and graph must be initialized."
        # Program-side tensor ids that have already received a graph attr.
        updated_tensor_ids = set()
        for node in self._serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                tensor_id = self._node_id_to_tensor_id[_node_id(node)]
                # If a var has multiple var nodes in the graph, only the
                # first one is used for now.
                if tensor_id not in updated_tensor_ids:
                    tensor_dist_attr_for_graph = (
                        self.get_tensor_dist_attr_for_graph(node)
                    )
                    dist_tensor_for_program = self._dist_tensors_for_program[
                        tensor_id
                    ]
                    dist_tensor_for_program.dist_attr = (
                        tensor_dist_attr_for_graph
                    )
                    updated_tensor_ids.add(tensor_id)
            if node.is_op() and node.op() is not None:
                op_id = self._node_id_to_op_id[_node_id(node)]
                op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
                dist_op_for_program = self._dist_ops_for_program[op_id]
                dist_op_for_program.dist_attr = op_dist_attr_for_graph

        # TODO: the completion algorithm skips orphan tensors, so here we
        # just set their process_mesh to the first one.
        for orphan_node in self._serial_orphan_tensor_nodes:
            orphan_var = orphan_node.var()
            # Look up by id first, then fall back to the original id.
            dist_tensor = self._dist_tensors_for_program.get(
                orphan_var.id(), None
            )
            if dist_tensor is None:
                dist_tensor = self._dist_tensors_for_program.get(
                    orphan_var.original_id(), None
                )
            assert (
                dist_tensor is not None
            ), "Orphan tensor (id: {}, original_id: {}) does not have a distributed tensor.".format(
                orphan_var.id(), orphan_var.original_id()
            )
            dist_tensor.dist_attr.process_mesh = self._process_meshes[0]

    def amend_dist_attr_for_program(self):
        """
        Relax the dims mappings of the program-side distributed attributes.

        For every program-side distributed tensor, and for every input and
        output argument of every program-side distributed operator, a
        dims_mapping entry that is not -1 is reset to -1 when either
          * the tensor dim is positive and smaller than the size of the
            process-mesh axis it is mapped to, or
          * the process mesh holds exactly one process.
        Operators whose process mesh holds exactly one process are also
        switched to ``impl_type = "default"`` and ``impl_idx = 0``.

        Vars whose type is in ``__no_shape_var_type__``, and operator
        arguments (inputs or outputs) whose serial var is missing, are treated
        as having no dims, so their mappings are left untouched.
        """

        def _tensor_shape(serial_var):
            # Missing vars and no-shape var types have no dims to amend.
            if serial_var is None or serial_var.type in __no_shape_var_type__:
                return []
            return serial_var.shape

        def _amend_dims_mapping(
            dims_mapping, tensor_shape, mesh_shape, mesh_processes
        ):
            # Edits ``dims_mapping`` in place; nothing is written back through
            # a setter. NOTE(review): presumably the list returned by the dist
            # attr is the live one — confirm against the dist attr class.
            for i in range(len(tensor_shape)):
                # If the dimension of the tensor is less than the sharding
                # dimension of the process mesh, amend the mapping to -1.
                # (Is this really OK?)
                if (
                    dims_mapping[i] != -1
                    and tensor_shape[i] > 0
                    and mesh_shape[dims_mapping[i]] > tensor_shape[i]
                ):
                    dims_mapping[i] = -1
                # A single-process mesh cannot shard any dimension.
                if dims_mapping[i] != -1 and len(mesh_processes) == 1:
                    dims_mapping[i] = -1

        for dist_tensor in self._dist_tensors_for_program.values():
            serial_tensor = dist_tensor.serial_tensor
            dist_attr = dist_tensor.dist_attr
            tensor_shape = _tensor_shape(serial_tensor)
            _amend_dims_mapping(
                dist_attr.dims_mapping,
                tensor_shape,
                dist_attr.process_mesh.topology,
                dist_attr.process_mesh.processes,
            )

        for dist_op in self._dist_ops_for_program.values():
            serial_op = dist_op.serial_op
            dist_attr = dist_op.dist_attr
            process_mesh_shape = dist_attr.process_mesh.topology
            process_mesh_processes = dist_attr.process_mesh.processes
            for arg_name in serial_op.input_arg_names:
                tensor_shape = _tensor_shape(dist_op.get_serial_input(arg_name))
                dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                _amend_dims_mapping(
                    dims_mapping,
                    tensor_shape,
                    process_mesh_shape,
                    process_mesh_processes,
                )
            for arg_name in serial_op.output_arg_names:
                # Missing output vars are skipped, same as missing inputs.
                tensor_shape = _tensor_shape(
                    dist_op.get_serial_output(arg_name)
                )
                dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                _amend_dims_mapping(
                    dims_mapping,
                    tensor_shape,
                    process_mesh_shape,
                    process_mesh_processes,
                )
            if len(process_mesh_processes) == 1:
                dist_op.dist_attr.impl_type = "default"
                dist_op.dist_attr.impl_idx = 0

    def validate_dist_attr_for_program(self):
        """
        Check that every tensor and operator of the serial main program has a
        valid distributed attribute.

        Returns:
            True when every check passes.

        Raises:
            AssertionError: if the context is not initialized, if a tensor or
                operator has no distributed counterpart, or if the counterpart's
                ``validate_dist_attr()`` returns a falsy value.
        """
        assert (
            self._is_initialized
        ), "Program must be initialized before validating its distributed attributes"
        for block in self.serial_main_program.blocks:
            for tensor in block.vars.values():
                dist_tensor = self.get_dist_tensor_for_program(tensor)
                # The message must not touch dist_tensor: it is None whenever
                # this assertion fires, so report the serial tensor's name.
                assert (
                    dist_tensor is not None
                ), "Tensor {} does not have a distributed attribute.".format(
                    tensor.name
                )
                # The explicit None check keeps this safe when asserts are
                # disabled (python -O).
                if (dist_tensor is not None) and (
                    not dist_tensor.validate_dist_attr()
                ):
                    assert (
                        False
                    ), "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
                        dist_tensor.serial_tensor.name,
                        dist_tensor.serial_tensor.desc.id(),
                        dist_tensor.serial_tensor.desc.original_id(),
                        dist_tensor.dist_attr,
                    )
            for op in block.ops:
                dist_op = self.get_dist_op_for_program(op)
                # Same as above: dist_op is None when this fires.
                assert (
                    dist_op is not None
                ), "Operator {} does not have a distributed attribute.".format(
                    op.type
                )
                if (dist_op is not None) and (not dist_op.validate_dist_attr()):
                    assert (
                        False
                    ), "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format(
                        dist_op.serial_op.type,
                        dist_op.serial_op.desc.id(),
                        dist_op.serial_op.desc.original_id(),
                        dist_op.dist_attr,
                    )
        return True

Z
zhaoyingli 已提交
989 990 991 992 993
    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
994
            if k in [
995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015
                "_original_serial_main_program",
                "_original_serial_startup_program",
                "_serial_main_program",
                "_serial_startup_program",
                "_serial_graph",
                "_dist_main_programs",
                "_dist_startup_programs",
                "_serial_ordered_nodes",
                "_serial_ordered_tensor_nodes",
                "_serial_ordered_op_nodes",
                "_original_serial_loss",
                "_original_serial_feed_vars",
                "_original_serial_fetch_vars",
                "_serial_loss",
                "_serial_feed_vars",
                "_serial_fetch_vars",
                "_serial_optimizer",
                "_backup_serial_main_program_stack",
                "_backup_serial_startup_program_stack",
                "_pass_context",
            ]:
Z
zhaoyingli 已提交
1016 1017 1018
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
1019 1020 1021 1022

        # update dist tensor's dist_context
        for key in result._dist_tensors_for_program.keys():
            result._dist_tensors_for_program[key]._dist_context = result
Z
zhaoyingli 已提交
1023 1024
        return result

1025 1026 1027 1028 1029 1030 1031 1032 1033

class DistributedOperatorContext:
    """
    Bookkeeping used while creating distributed op descs in a Program.

    ``prepare_context`` must be called for each source op before the
    corresponding distributed op is created, so the context reflects it.
    """

    def __init__(self):
        # Destination programs and their first blocks (``blocks[0]``).
        self._dst_main_program = None
        self._main_block = None
        self._dst_startup_program = None
        self._startup_block = None
        # Source op (and dist attr) currently being converted.
        self._cur_src_op = None
        self._cur_dist_attr = None
        self.grad_op_id_to_op_id = {}
        self.grad_var_to_var = defaultdict(dict)
        self._work_block = None
        self.already_init_sync_vars = set()
        self.varname_mapping = None
        self.rank_id = None
        # NOTE Support correct parallelism for high-order differential models.
        # The flag starts as False, meaning we are in the forward phase; once
        # it becomes True we are in the backward phase. The final solution
        # should be to revise the high-order differential logic for these two
        # phases in the future.
        self._exceed_backward_init_op = False

    def __deepcopy__(self, memo):
        # Program/block/op handles are shared by reference with the copy;
        # every other attribute is deep-copied.
        shared_attrs = (
            "_dst_main_program",
            "_dst_startup_program",
            "_cur_src_op",
            "_work_block",
            "_main_block",
            "_startup_block",
        )
        clone = self.__class__.__new__(self.__class__)
        memo[id(self)] = clone
        for attr_name, attr_value in self.__dict__.items():
            setattr(
                clone,
                attr_name,
                attr_value
                if attr_name in shared_attrs
                else copy.deepcopy(attr_value, memo),
            )
        return clone

    @property
    def dst_main_program(self):
        return self._dst_main_program

    @dst_main_program.setter
    def dst_main_program(self, prog):
        # Also cache the program's first block as the main block.
        self._dst_main_program = prog
        self._main_block = prog.blocks[0]

    @property
    def main_block(self):
        return self._main_block

    @property
    def dst_startup_program(self):
        return self._dst_startup_program

    @dst_startup_program.setter
    def dst_startup_program(self, prog):
        # Also cache the program's first block as the startup block.
        self._dst_startup_program = prog
        self._startup_block = prog.blocks[0]

    @property
    def startup_block(self):
        return self._startup_block

    @property
    def work_block(self):
        assert self._work_block is not None
        return self._work_block

    @work_block.setter
    def work_block(self, block):
        assert block is not None
        self._work_block = block

    @property
    def cur_src_op(self):
        assert self._cur_src_op is not None
        return self._cur_src_op

    def in_backward_phase(self):
        return self._exceed_backward_init_op

    def prepare_context(self, src_op):
        """
        Make ``src_op`` the current source op and map its var names.

        Once ``is_loss_grad_op(src_op)`` is true, the context switches to the
        backward phase and stays there.

        Returns:
            ``(kinputs, koutputs)``: dicts from each input/output slot name of
            ``src_op.desc`` to the list of its var names translated through
            ``self.varname_mapping`` (every name must be present there).
        """
        self._cur_src_op = src_op

        if is_loss_grad_op(src_op):
            self._exceed_backward_init_op = True

        def _mapped(varname):
            assert varname in self.varname_mapping
            return self.varname_mapping[varname]

        kinputs = {
            slot: [_mapped(varname) for varname in src_op.desc.input(slot)]
            for slot in src_op.desc.input_names()
        }
        koutputs = {
            slot: [_mapped(varname) for varname in src_op.desc.output(slot)]
            for slot in src_op.desc.output_names()
        }
        return kinputs, koutputs


1141
class BlockState:
    """
    Record which blocks of a program are forward blocks and which are
    backward blocks, and map every backward block to its forward block.
    """

    def __init__(self):
        self.nblock = 0
        self.forward_indices = []
        self.backward_indices = []
        self.backward_to_forward_index_map = {}

    def parse_forward_blocks(self, program):
        """
        Register every block of ``program`` as a forward block.

        The program is first rolled back until its current block index is 0.
        Each block must have ``idx`` equal to its position and
        ``forward_block_idx == -1``.
        """
        while program.current_block_idx != 0:
            program._rollback()

        assert program.current_block_idx == 0

        for position, blk in enumerate(program.blocks):
            assert position == blk.idx, "index doesn't match"
            assert blk.forward_block_idx == -1, (
                "forward_block_idx of forward block [{}] is not [{}]".format(
                    position, blk.forward_block_idx
                )
            )
            self.forward_indices.append(position)
            self.nblock += 1

        assert self.nblock >= 1

    def parse_backward_blocks(self, program):
        """
        Register the blocks after the forward ones as backward blocks.

        Block 0 maps to itself. Every block at a position at or past
        ``len(self.forward_indices)`` must point (via ``forward_block_idx``)
        to a known forward block. Afterwards the total block count must
        match ``len(program.blocks)``.
        """
        assert 0 in self.forward_indices, "forward block idx are{}".format(
            self.forward_indices
        )
        self.backward_to_forward_index_map[0] = 0

        first_backward = len(self.forward_indices)
        for position, blk in enumerate(program.blocks):
            if position < first_backward:
                continue

            assert position == blk.idx, "index doesn't match"
            assert blk.forward_block_idx in self.forward_indices
            self.backward_indices.append(position)
            self.backward_to_forward_index_map[position] = blk.forward_block_idx
            self.nblock += 1

        assert self.nblock == len(program.blocks)