#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy
from collections import defaultdict

from paddle.distributed.passes import PassContext
from paddle.framework import IrGraph, core, set_flags

from .dist_op import DistributedOperator
from .dist_tensor import DistributedTensor
from .process_mesh import ProcessMesh
from .utils import (
    __no_shape_var_type__,
    _copy_dist_attr_to_cpp,
    is_loss_grad_op,
)

# There always exists a default context for the user, and the user can set it to another one.
_g_default_distributed_context = None


def get_default_distributed_context():
    global _g_default_distributed_context
    if _g_default_distributed_context is None:
        dist_context = DistributedContext()
        set_default_distributed_context(dist_context)
    return _g_default_distributed_context


def set_default_distributed_context(dist_context):
    global _g_default_distributed_context
    _g_default_distributed_context = dist_context

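# Illustrative sketch of how the default context is typically used (the 1-D mesh
# ProcessMesh([0, 1]) below is a hypothetical value, not a prescribed one):
#
#   ctx = get_default_distributed_context()
#   ctx.add_process_mesh(ProcessMesh([0, 1]))
#   assert ctx is get_default_distributed_context()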

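# A graph node is identified by the pair (graph_id, node_id): once all blocks are
# converted into sub-graphs, node ids alone are not guaranteed to be unique
# across sub-graphs.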
def _node_id(node):
    return (node.node.graph_id(), node.node.id())


class DistributedContext:
    """
    DistributedContext is used to collect related distributed information for program and graph.
    One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
    """

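    # A minimal usage sketch (illustrative only; `main_prog`, `startup_prog`,
    # `loss` and `optimizer` are hypothetical objects built by the caller):
    #
    #   dist_context = DistributedContext(
    #       serial_main_prog=main_prog,
    #       serial_startup_prog=startup_prog,
    #       serial_optimizer=optimizer,
    #       serial_loss=loss,
    #   )
    #   dist_context.initialize(with_graph=False)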
    def __init__(
        self,
        serial_main_prog=None,
        serial_startup_prog=None,
        serial_optimizer=None,
        serial_loss=None,
        feed_vars={},
        fetch_vars={},
        cluster=None,
        strategy=None,
    ):
        # Data members related to original programs (unchanged)
        self._original_serial_main_program = serial_main_prog
        self._original_serial_startup_program = serial_startup_prog
        self._original_serial_optimizer = serial_optimizer
        self._original_serial_loss = serial_loss
        self._original_serial_feed_vars = feed_vars
        self._original_serial_fetch_vars = fetch_vars

        # Data members related to programs (changed)
        self._serial_main_program = None
        self._serial_startup_program = None
        self._serial_loss = None
        self._serial_optimizer = None
        self._serial_feed_vars = {}
        self._serial_fetch_vars = {}
        self._lr_optimizer = None  # record the optimizer holding lr_scheduler

        # Data members related to the program
        self._dist_tensors_for_program = {}
        self._dist_ops_for_program = {}

        # Data members related to the graph
        self._serial_graph = None
        self._dist_tensors_for_graph = {}
        self._dist_ops_for_graph = {}
        self._node_id_to_tensor_id = {}
        self._node_id_to_op_id = {}

        # Data members related to the distributed programs
        # Distributed programs
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._dist_op_context = DistributedOperatorContext()
        self._process_meshes = []

        self._cluster = cluster
        self._strategy = strategy

        # Pass Context
        self._pass_context = PassContext()
        self._block_state = BlockState()

        # Other data members
        self._serial_ordered_tensor_nodes = []
        self._serial_ordered_op_nodes = []
        self._serial_ordered_nodes = []
        # self._tensor_id_to_tensor_node_ids = {}

        self._is_initialized = False
        # TODO: need a better way to remove the following flag
        self._need_copy_dist_attr_to_graph = False
        self._backup_pass_context_stack = []
        self._backup_block_state_stack = []
        self._backup_dist_tensors_for_program_stack = []
        self._backup_dist_ops_for_program_stack = []
        self._backup_serial_main_program_stack = []
        self._backup_serial_startup_program_stack = []

        # Flag of whether to scale the gradient with the dp size
        self._gradient_scale = True

        # A flag indicating whether the used parallelism is data parallel
        self._data_parallel = False

    @property
    def serial_main_program(self):
        return self._serial_main_program

    @property
    def serial_startup_program(self):
        return self._serial_startup_program

    @property
    def serial_loss(self):
        return self._serial_loss

    @property
    def serial_optimizer(self):
        return self._serial_optimizer

    @property
    def serial_feed_vars(self):
        return self._serial_feed_vars

    @property
    def serial_fetch_vars(self):
        return self._serial_fetch_vars

    @property
    def dist_main_programs(self):
        return self._dist_main_programs

    @property
    def dist_startup_programs(self):
        return self._dist_startup_programs

    @property
    def cluster(self):
        return self._cluster

    @property
    def strategy(self):
        return self._strategy

    @property
    def serial_graph(self):
        return self._serial_graph

    @property
    def serial_ordered_nodes(self):
        return self._serial_ordered_nodes

    @property
    def process_meshes(self):
        return self._process_meshes

    @property
    def pass_context(self):
        return self._pass_context

    @property
    def dist_op_context(self):
        return self._dist_op_context

    @property
    def block_state(self):
        return self._block_state

    @property
    def has_annotation(self):
        return len(self._dist_tensors_for_program) or len(
            self._dist_ops_for_program
        )

    @property
    def gradient_scale(self):
        return self._gradient_scale

    @gradient_scale.setter
    def gradient_scale(self, gs):
        self._gradient_scale = gs

    @property
    def data_parallel(self):
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, dp):
        self._data_parallel = dp

    def _backup_serial_info(self, mode):
        self._backup_serial_main_program_stack.append(
            self._serial_main_program.clone()
        )
        self._backup_serial_startup_program_stack.append(
            self._serial_startup_program.clone()
        )
        self._backup_pass_context_stack.append(
            copy.deepcopy(self._pass_context)
        )
        self._backup_block_state_stack.append(copy.deepcopy(self._block_state))

    def _backup_dist_info(self, mode):
        self._backup_dist_tensors_for_program_stack.append(
            copy.deepcopy(self._dist_tensors_for_program)
        )
        self._backup_dist_ops_for_program_stack.append(
            copy.deepcopy(self._dist_ops_for_program)
        )

    def _backup(self, serial=True, serial_mode=None, dist=True, dist_mode=None):
        # Use this function carefully
        if serial:
            self._backup_serial_info(serial_mode)
        if dist:
            self._backup_dist_info(dist_mode)

    def _restore_serial_loss(self):
        if self._original_serial_loss:
            if isinstance(self._original_serial_loss, list):
                if len(self._original_serial_loss) == 1:
                    loss = self._original_serial_loss[0]
                    block_idx = loss.block.idx
                    var_name = loss.name
                    var = self._serial_main_program.blocks[
                        block_idx
                    ]._var_recursive(var_name)
                    self._serial_loss = var
                elif len(self._original_serial_loss) == 0:
                    self._serial_loss = []
                else:
                    raise ValueError("multi loss vars are not supported.")
            else:
                block_idx = self._original_serial_loss.block.idx
                var_name = self._original_serial_loss.name
                var = self._serial_main_program.blocks[
                    block_idx
                ]._var_recursive(var_name)
                self._serial_loss = var

    def _restore_serial_feed_vars(self):
        for key, var_list in self._original_serial_feed_vars.items():
            new_var_list = []
            for var in var_list:
                block_idx = var.block.idx
                var_name = var.name
                var = self._serial_main_program.blocks[
                    block_idx
                ]._var_recursive(var_name)
                new_var_list.append(var)
            self._serial_feed_vars[key] = new_var_list

    def _restore_serial_fetch_vars(self):
        for key, var_list in self._original_serial_fetch_vars.items():
            new_var_list = []
            # metrics is a list of lists
            if key == "metrics":
                for inner_var_list in var_list:
                    new_inner_var_list = []
                    for var in inner_var_list:
                        block_idx = var.block.idx
                        var_name = var.name
                        var = self._serial_main_program.blocks[
                            block_idx
                        ]._var_recursive(var_name)
                        new_inner_var_list.append(var)
                    new_var_list.append(new_inner_var_list)
            else:
                for var in var_list:
                    block_idx = var.block.idx
                    var_name = var.name
                    var = self._serial_main_program.blocks[
                        block_idx
                    ]._var_recursive(var_name)
                    new_var_list.append(var)
            self._serial_fetch_vars[key] = new_var_list

    def _restore_serial_info(self, mode="to_backup"):
        if mode == "to_backup":
            self._serial_main_program = (
                self._backup_serial_main_program_stack.pop()
            )
            self._serial_startup_program = (
                self._backup_serial_startup_program_stack.pop()
            )
        elif mode == "to_original":
            assert self._original_serial_main_program is not None
            assert self._original_serial_startup_program is not None
            self._serial_main_program = (
                self._original_serial_main_program.clone()
            )
            self._serial_startup_program = (
                self._original_serial_startup_program.clone()
            )

        self._restore_serial_loss()
        self._restore_serial_feed_vars()
        self._restore_serial_fetch_vars()
        self._serial_optimizer = self._original_serial_optimizer
        self._pass_context = self._backup_pass_context_stack.pop()
        self._block_state = self._backup_block_state_stack.pop()

    def _restore_dist_info(self, mode="to_backup"):
        if mode == "to_backup":
            self._dist_tensors_for_program = (
                self._backup_dist_tensors_for_program_stack.pop()
            )
            self._dist_ops_for_program = (
                self._backup_dist_ops_for_program_stack.pop()
            )
        elif mode == "to_original":
            assert self._original_dist_tensors_for_program
            assert self._original_dist_ops_for_program
            self._dist_tensors_for_program = copy.deepcopy(
                self._original_dist_tensors_for_program
            )
            self._dist_ops_for_program = copy.deepcopy(
                self._original_dist_ops_for_program
            )
        elif mode == "to_default":
            new_tensors_ids = []
            for (
                tensor_id,
                dist_tensor,
            ) in self._dist_tensors_for_program.items():
                if tensor_id in self._tensors_ids:
                    dist_tensor.dist_attr.reset()
                else:
                    new_tensors_ids.append(tensor_id)
            for tensor_id in new_tensors_ids:
                self._dist_tensors_for_program.pop(tensor_id)
            new_ops_ids = []
            for op_id, dist_op in self._dist_ops_for_program.items():
                if op_id in self._ops_ids:
                    dist_op.dist_attr.reset()
                else:
                    new_ops_ids.append(op_id)
            for op_id in new_ops_ids:
                self._dist_ops_for_program.pop(op_id)
        else:
            new_tensors_ids = []
            for (
                tensor_id,
                dist_tensor,
            ) in self._dist_tensors_for_program.items():
                new_tensors_ids.append(tensor_id)
            for tensor_id in new_tensors_ids:
                self._dist_tensors_for_program.pop(tensor_id)
            new_ops_ids = []
            for op_id, dist_op in self._dist_ops_for_program.items():
                new_ops_ids.append(op_id)
            for op_id in new_ops_ids:
                self._dist_ops_for_program.pop(op_id)
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._dist_op_context = DistributedOperatorContext()
        self._need_copy_dist_attr_to_graph = True
        self._process_meshes = []

    def _restore(
        self,
        serial=True,
        serial_mode="to_backup",
        dist=True,
        dist_mode="to_backup",
    ):
        # Use this function carefully
        if serial:
            self._restore_serial_info(serial_mode)
        if dist:
            self._restore_dist_info(dist_mode)

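    # Typical backup/restore pattern (an illustrative sketch, not a prescribed API):
    #
    #   dist_context._backup(serial=True, dist=True)
    #   ...  # mutate the serial programs or the distributed annotations
    #   dist_context._restore(serial=True, serial_mode="to_backup",
    #                         dist=True, dist_mode="to_backup")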
    def initialize(self, with_graph=True, with_cpp=False):
        if not self._is_initialized:
            if not self._serial_main_program:
                if self._original_serial_main_program:
                    self._serial_main_program = (
                        self._original_serial_main_program.clone()
                    )
            if not self._serial_startup_program:
                if self._original_serial_startup_program:
                    self._serial_startup_program = (
                        self._original_serial_startup_program.clone()
                    )
            if not self._serial_loss:
                self._restore_serial_loss()
            if not self._serial_optimizer:
                self._serial_optimizer = self._original_serial_optimizer
            if not self._serial_feed_vars:
                self._restore_serial_feed_vars()
            if not self._serial_fetch_vars:
                self._restore_serial_fetch_vars()

            self._init_dist_attr_for_program()
            # Backup the original distributed information for later restore
            self._original_dist_tensors_for_program = copy.deepcopy(
                self._dist_tensors_for_program
            )
            self._original_dist_ops_for_program = copy.deepcopy(
                self._dist_ops_for_program
            )
            self._tensors_ids = list(self._dist_tensors_for_program.keys())
            self._ops_ids = list(self._dist_ops_for_program.keys())
            self._is_initialized = True

            # TODO: This will be removed in the future
            if with_cpp:
                _copy_dist_attr_to_cpp(self)

            if with_graph:
                set_flags({"FLAGS_convert_all_blocks": True})
                self._serial_graph = IrGraph(
                    core.Graph(self._serial_main_program.desc)
                )
                self._init_dist_attr_for_graph()
                self._need_copy_dist_attr_to_graph = False

        if self._need_copy_dist_attr_to_graph and with_graph:
            self.copy_dist_attr_from_program_to_graph()

    def add_process_mesh(self, process_mesh):
        assert isinstance(
            process_mesh, (ProcessMesh, core.ProcessMesh)
        ), 'The type of process_mesh must be ProcessMesh or core.ProcessMesh.'
        if process_mesh not in self.process_meshes:
            self._process_meshes.append(process_mesh)

    def add_dist_tensor_for_program(self, dist_tensor):
        inner_serial_tensor = dist_tensor.serial_tensor
        inner_serial_tensor_id = inner_serial_tensor.desc.original_id()
        self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor

    def add_dist_op_for_program(self, dist_op):
        inner_serial_op = dist_op.serial_op
        inner_serial_op_id = inner_serial_op.desc.original_id()
        self._dist_ops_for_program[inner_serial_op_id] = dist_op

    def get_dist_tensor_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor
        else:
            serial_tensor_id = serial_tensor.desc.original_id()
            dist_tensor = self._dist_tensors_for_program.get(
                serial_tensor_id, None
            )
            if dist_tensor:
                return dist_tensor
            else:
                return None

    def get_dist_tensor_for_graph(self, serial_tensor_node):
        serial_tensor_node_id = _node_id(serial_tensor_node)
        return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)

    def get_dist_op_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op
        else:
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op
            else:
                return None

    def del_dist_op_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        if self._dist_ops_for_program.get(serial_tensor_id, None):
            del self._dist_ops_for_program[serial_tensor_id]

    def get_dist_op_for_graph(self, serial_op_node):
        serial_op_node_id = _node_id(serial_op_node)
        return self._dist_ops_for_graph.get(serial_op_node_id, None)

    def get_tensor_dist_attr_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            serial_tensor_id = serial_tensor.desc.original_id()
            dist_tensor = self._dist_tensors_for_program.get(
                serial_tensor_id, None
            )
            if dist_tensor:
                return dist_tensor.dist_attr
            else:
                return None

    def get_tensor_dist_attr_for_program_with_id(self, tensor_id):
        dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            return None

    def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
        dist_tensor = DistributedTensor(serial_tensor, dist_attr)
        self.add_dist_tensor_for_program(dist_tensor)

    def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
        serial_tensor_node_id = _node_id(serial_tensor_node)
        dist_tensor = self._dist_tensors_for_graph.get(
            serial_tensor_node_id, None
        )
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            return None

    def get_op_dist_attr_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op.dist_attr
            else:
                return None

    def get_op_dist_attr_for_program_with_id(self, op_id):
        dist_op = self._dist_ops_for_program.get(op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            return None

    def set_op_dist_attr_for_program(self, serial_op, dist_attr):
        dist_op = DistributedOperator(serial_op, dist_attr)
        self.add_dist_op_for_program(dist_op)

    def get_op_dist_attr_for_graph(self, serial_op_node):
        serial_op_node_id = _node_id(serial_op_node)
        dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            return None

    def get_dist_attr_for_graph(self, serial_node):
        if serial_node.is_var() and serial_node.var() is not None:
            serial_tensor_node_id = _node_id(serial_node)
            dist_tensor = self._dist_tensors_for_graph.get(
                serial_tensor_node_id, None
            )
            if dist_tensor:
                return dist_tensor.dist_attr
            else:
                return None
        if serial_node.is_op() and serial_node.op() is not None:
            serial_op_node_id = _node_id(serial_node)
            dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
            if dist_op:
                return dist_op.dist_attr
            else:
                return None
        return None

    def _init_dist_attr_for_program(self, no_default=False):
        # Copy the dist tensors and dist ops annotated by users from the default context
        if not no_default:
            default_ctx = get_default_distributed_context()
            self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
        else:
            default_ctx = self
        # Copy the data parallel flag from the default context
        self._data_parallel = default_ctx.data_parallel
        for block in self._serial_main_program.blocks:
            for tensor in block.vars.values():
                # Copy the distributed tensors in the default context
                default_dist_tensor = default_ctx.get_dist_tensor_for_program(
                    tensor
                )
                if default_dist_tensor and default_ctx is not self:
                    dist_tensor = DistributedTensor(tensor)
                    dist_tensor.dist_attr = copy.deepcopy(
                        default_dist_tensor.dist_attr
                    )
                    self.add_dist_tensor_for_program(dist_tensor)
                current_dist_tensor = self.get_dist_tensor_for_program(tensor)
                if current_dist_tensor is None:
                    dist_tensor = DistributedTensor(tensor)
                    self.add_dist_tensor_for_program(dist_tensor)
            for op in block.ops:
                # Copy the distributed operators in the default context
                default_dist_op = default_ctx.get_dist_op_for_program(op)
                if default_dist_op and default_ctx is not self:
                    dist_op = DistributedOperator(op)
                    dist_op.dist_attr = copy.deepcopy(default_dist_op.dist_attr)
                    self.add_dist_op_for_program(dist_op)
                current_dist_op = self.get_dist_op_for_program(op)
                if current_dist_op is None:
                    dist_op = DistributedOperator(op)
                    self.add_dist_op_for_program(dist_op)
        self._original_dist_tensors_for_program = copy.deepcopy(
            self._dist_tensors_for_program
        )
        self._original_dist_ops_for_program = copy.deepcopy(
            self._dist_ops_for_program
        )

    def _order_nodes_by_program_order(self):
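        # Rebuild an ordering that follows the serial program: op nodes are sorted
        # by their original desc id, and each op's not-yet-visited input tensor
        # nodes are emitted just before it, its output tensor nodes just after it.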
        serial_ordered_tensor_nodes = []
        serial_ordered_op_nodes = []
        all_nodes = []
        visited = {}
        for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
            for node in graph.all_nodes():
                all_nodes.append(node)
        for node in all_nodes:
            if node.is_var() and node.var() is not None:
                serial_ordered_tensor_nodes.append(node)
                visited[_node_id(node)] = False
            if node.is_op() and node.op() is not None:
                serial_ordered_op_nodes.append(node)
        serial_ordered_tensor_nodes.sort(
            key=lambda node: node.node.original_desc_id()
        )
        serial_ordered_op_nodes.sort(
            key=lambda node: node.node.original_desc_id()
        )
        num_nodes_before = len(serial_ordered_tensor_nodes) + len(
            serial_ordered_op_nodes
        )

        new_serial_ordered_tensor_nodes = []
        new_serial_ordered_op_nodes = []
        new_serial_ordered_nodes = []
        for op_node in serial_ordered_op_nodes:
            tensor_nodes = []
            for tensor_node in op_node.inputs:
                if (
                    tensor_node.is_var()
                    and tensor_node.var() is not None
                    and not visited[_node_id(tensor_node)]
                ):
                    tensor_nodes.append(tensor_node)
                    new_serial_ordered_tensor_nodes.append(tensor_node)
                    visited[_node_id(tensor_node)] = True

            tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
            new_serial_ordered_nodes.extend(tensor_nodes)
            new_serial_ordered_nodes.append(op_node)
            new_serial_ordered_op_nodes.append(op_node)
            tensor_nodes = []
            for tensor_node in op_node.outputs:
                if (
                    tensor_node.is_var()
                    and tensor_node.var() is not None
                    and not visited[_node_id(tensor_node)]
                ):
                    tensor_nodes.append(tensor_node)
                    new_serial_ordered_tensor_nodes.append(tensor_node)
                    visited[_node_id(tensor_node)] = True
            tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
            new_serial_ordered_nodes.extend(tensor_nodes)
        new_serial_ordered_tensor_nodes.sort(
            key=lambda node: node.node.original_desc_id()
        )
        new_serial_ordered_op_nodes.sort(
            key=lambda node: node.node.original_desc_id()
        )
        self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
        self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
        self._serial_ordered_nodes = new_serial_ordered_nodes
        assert len(self._serial_ordered_nodes) == len(
            self._serial_ordered_tensor_nodes
        ) + len(self._serial_ordered_op_nodes)
        # graph_id -> tensor_name -> list of (idx, node)
        self._tensor_nodes_with_same_name = defaultdict(dict)
        for idx, node in enumerate(self._serial_ordered_nodes):
            if node.is_var() and node.var() is not None:
                graph_id = node.node.graph_id()
                tensor_name = node.var().name()
                if (
                    self._tensor_nodes_with_same_name[graph_id].get(
                        tensor_name, None
                    )
                    is None
                ):
                    self._tensor_nodes_with_same_name[graph_id][
                        tensor_name
                    ] = []
                self._tensor_nodes_with_same_name[graph_id][tensor_name].append(
                    (idx, node)
                )

        self._serial_orphan_tensor_nodes = []
        for tensor_node in serial_ordered_tensor_nodes:
            if not visited[_node_id(tensor_node)]:
                self._serial_orphan_tensor_nodes.append(tensor_node)
        if len(self._serial_ordered_nodes) != num_nodes_before:
            print(
                "WARNING: there are some orphan tensors or ops which are not used in the execution."
            )

    def _init_dist_attr_for_graph(self):
        # Convert program to graph and initialize the distributed attributes
        self._order_nodes_by_program_order()
        self._tensor_original_id_to_id = {}
        self._op_original_id_to_id = {}
        for tensor_id, tensor in self._dist_tensors_for_program.items():
            original_id = tensor.serial_tensor.desc.original_id()
            self._tensor_original_id_to_id[original_id] = tensor_id
        for op_id, op in self._dist_ops_for_program.items():
            original_id = op.serial_op.desc.original_id()
            self._op_original_id_to_id[original_id] = op_id
        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                dist_tensor = None
                tensor_id = node.node.original_desc_id()
                cur_dist_tensor = self._dist_tensors_for_program.get(
                    tensor_id, None
                )
                if cur_dist_tensor is not None:
                    cur_tensor_id = tensor_id
                else:
                    cur_tensor_id = self._tensor_original_id_to_id[tensor_id]
                    cur_dist_tensor = self._dist_tensors_for_program.get(
                        cur_tensor_id, None
                    )
                dist_tensor = cur_dist_tensor
                self._node_id_to_tensor_id[_node_id(node)] = cur_tensor_id
                assert (
                    dist_tensor is not None
                ), "Tensor must have a distributed tensor after the initialization for program."
                serial_tensor_node_id = _node_id(node)
                new_dist_tensor = DistributedTensor(
                    dist_tensor.serial_tensor, dist_tensor.dist_attr
                )
                self._dist_tensors_for_graph[
                    serial_tensor_node_id
                ] = new_dist_tensor
            if node.is_op() and node.op() is not None:
                dist_op = None
                op_id = node.node.original_desc_id()
                cur_dist_op = self._dist_ops_for_program.get(op_id, None)
                if cur_dist_op is not None:
                    cur_op_id = op_id
                else:
                    cur_op_id = self._op_original_id_to_id[op_id]
                    cur_dist_op = self._dist_ops_for_program.get(
                        cur_op_id, None
                    )
                dist_op = cur_dist_op
                self._node_id_to_op_id[_node_id(node)] = cur_op_id
                assert (
                    dist_op is not None
                ), "Operator must have a distributed operator after the initialization for program."
                serial_op_node_id = _node_id(node)
                new_dist_op = DistributedOperator(
                    dist_op.serial_op, dist_op.dist_attr
                )
                self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

    def clear_dist_info_for_program(self):
        self._dist_tensors_for_program.clear()
        self._dist_ops_for_program.clear()

    def clear_dist_info_for_graph(self):
        self._dist_tensors_for_graph.clear()
        self._dist_ops_for_graph.clear()

    def copy_dist_attr_from_program_to_graph(self):
        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                dist_tensor = None
                tensor_id = node.node.original_desc_id()
                cur_dist_tensor = self._dist_tensors_for_program.get(
                    tensor_id, None
                )
                if cur_dist_tensor is not None:
                    cur_tensor_id = tensor_id
                else:
                    cur_tensor_id = self._tensor_original_id_to_id[tensor_id]
                    cur_dist_tensor = self._dist_tensors_for_program.get(
                        cur_tensor_id, None
                    )
                dist_tensor = cur_dist_tensor
                assert (
                    dist_tensor is not None
                ), "Tensor must have a distributed tensor after the initialization for program."
                serial_tensor_node_id = _node_id(node)
                new_dist_tensor = DistributedTensor(
                    dist_tensor.serial_tensor, dist_tensor.dist_attr
                )
                self._dist_tensors_for_graph[
                    serial_tensor_node_id
                ] = new_dist_tensor
            if node.is_op() and node.op() is not None:
                dist_op = None
                op_id = node.node.original_desc_id()
                cur_dist_op = self._dist_ops_for_program.get(op_id, None)
                if cur_dist_op is not None:
                    cur_op_id = op_id
                else:
                    cur_op_id = self._op_original_id_to_id[op_id]
                    cur_dist_op = self._dist_ops_for_program.get(
                        cur_op_id, None
                    )
                dist_op = cur_dist_op
                assert (
                    dist_op is not None
                ), "Operator must have a distributed operator after the initialization for program."
                serial_op_node_id = _node_id(node)
                new_dist_op = DistributedOperator(
                    dist_op.serial_op, dist_op.dist_attr
                )
                self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

    def copy_dist_attr_from_graph_to_program(self):
        assert (
            self._is_initialized
        ), "Both program and graph must be initialized."
        updated_tensors = {}
        # all_nodes = self._serial_graph.all_nodes()
        all_nodes = self._serial_ordered_nodes
        for node in all_nodes:
            if node.is_var() and node.var() is not None:
                tensor_id = self._node_id_to_tensor_id[_node_id(node)]
                updated = updated_tensors.get(tensor_id, False)
                # If a var has multiple var nodes in the graph, only use the first one for now
                if not updated:
                    tensor_dist_attr_for_graph = (
                        self.get_tensor_dist_attr_for_graph(node)
                    )
                    dist_tensor_for_program = self._dist_tensors_for_program[
                        tensor_id
                    ]
                    dist_tensor_for_program.dist_attr = (
                        tensor_dist_attr_for_graph
                    )
                    updated_tensors[tensor_id] = True
            if node.is_op() and node.op() is not None:
                op_id = self._node_id_to_op_id[_node_id(node)]
                op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
                dist_op_for_program = self._dist_ops_for_program[op_id]
                dist_op_for_program.dist_attr = op_dist_attr_for_graph
        # TODO: the completion algorithm will skip orphan tensors,
        # here we just set their process_mesh to the first one.
        for orphan_node in self._serial_orphan_tensor_nodes:
            serial_tensor_id = orphan_node.var().id()
            dist_tensor = self._dist_tensors_for_program.get(
                serial_tensor_id, None
            )
            if dist_tensor:
                dist_tensor.dist_attr.process_mesh = self._process_meshes[0]
            else:
                serial_tensor_id = orphan_node.var().original_id()
                dist_tensor = self._dist_tensors_for_program.get(
                    serial_tensor_id, None
                )
                dist_tensor.dist_attr.process_mesh = self._process_meshes[0]

    def amend_dist_attr_for_program(self):
        for dist_tensor in self._dist_tensors_for_program.values():
            serial_tensor = dist_tensor.serial_tensor
            dist_attr = dist_tensor.dist_attr
            if serial_tensor.type in __no_shape_var_type__:
                tensor_shape = []
            else:
                tensor_shape = serial_tensor.shape
            dims_mapping = dist_attr.dims_mapping
            process_mesh_shape = dist_attr.process_mesh.shape
            process_mesh_processes = dist_attr.process_mesh.process_ids
            # If the dimension of tensor is less than the sharding dimension of process mesh,
            # we just amend the dimension mapping to -1. (Is this really OK?)
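            # For example (hypothetical values): with process_mesh_shape = [4] and
            # tensor_shape = [2], a dims_mapping of [0] would shard a size-2 dim
            # across 4 processes, so that entry is reset to -1 (replicated).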
            for i in range(len(tensor_shape)):
                if (
                    dims_mapping[i] != -1
                    and tensor_shape[i] > 0
                    and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]
                ):
                    dims_mapping[i] = -1
                if dims_mapping[i] != -1 and len(process_mesh_processes) == 1:
                    dims_mapping[i] = -1
            dist_attr.dims_mapping = dims_mapping

        for dist_op in self._dist_ops_for_program.values():
            serial_op = dist_op.serial_op
            dist_attr = dist_op.dist_attr
            process_mesh_shape = dist_attr.process_mesh.shape
            process_mesh_processes = dist_attr.process_mesh.process_ids
            for arg_name in serial_op.input_arg_names:
                if dist_op.get_serial_input(arg_name) is None:
                    tensor_shape = []
                else:
                    if (
                        dist_op.get_serial_input(arg_name).type
                        in __no_shape_var_type__
                    ):
                        tensor_shape = []
                    else:
                        tensor_shape = dist_op.get_serial_input(arg_name).shape
                dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                # If the dimension of tensor is less than the sharding dimension of process mesh,
                # we just amend the dimension mapping to -1. (Is this really OK?)
                for i in range(len(tensor_shape)):
                    if (
                        dims_mapping[i] != -1
                        and tensor_shape[i] > 0
                        and process_mesh_shape[dims_mapping[i]]
                        > tensor_shape[i]
                    ):
                        dims_mapping[i] = -1
                    if (
                        dims_mapping[i] != -1
                        and len(process_mesh_processes) == 1
                    ):
                        dims_mapping[i] = -1
                dist_attr.set_input_dims_mapping(arg_name, dims_mapping)
            for arg_name in serial_op.output_arg_names:
                if (
                    dist_op.get_serial_output(arg_name).type
                    in __no_shape_var_type__
                ):
                    tensor_shape = []
                else:
                    tensor_shape = dist_op.get_serial_output(arg_name).shape
                dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                # If the dimension of tensor is less than the sharding dimension of process mesh,
                # we just amend the dimension mapping to -1. (Is this really OK?)
                for i in range(len(tensor_shape)):
                    if (
                        dims_mapping[i] != -1
                        and tensor_shape[i] > 0
                        and process_mesh_shape[dims_mapping[i]]
                        > tensor_shape[i]
                    ):
                        dims_mapping[i] = -1
                    if (
                        dims_mapping[i] != -1
                        and len(process_mesh_processes) == 1
                    ):
                        dims_mapping[i] = -1
                dist_attr.set_output_dims_mapping(arg_name, dims_mapping)
            if len(process_mesh_processes) == 1:
                dist_op.dist_attr.impl_type = "default"
                dist_op.dist_attr.impl_idx = 0

    def validate_dist_attr_for_program(self):
        if not self._is_initialized:
            assert (
                False
            ), "Program must be initialized before validating its distributed attributes"
        for block in self.serial_main_program.blocks:
            for tensor in block.vars.values():
                dist_tensor = self.get_dist_tensor_for_program(tensor)
                assert (
                    dist_tensor is not None
                ), "Tensor {} does not have a distributed attribute.".format(
                    tensor.name
                )
                if (dist_tensor is not None) and (
                    not dist_tensor.validate_dist_attr()
                ):
                    assert (
                        False
                    ), "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
                        dist_tensor.serial_tensor.name,
                        dist_tensor.serial_tensor.desc.id(),
                        dist_tensor.serial_tensor.desc.original_id(),
                        dist_tensor.dist_attr,
                    )
            for op in block.ops:
                dist_op = self.get_dist_op_for_program(op)
                assert (
                    dist_op is not None
                ), "Operator {} does not have a distributed attribute.".format(
                    op.type
                )
                if (dist_op is not None) and (not dist_op.validate_dist_attr()):
                    assert (
                        False
                    ), "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format(
                        dist_op.serial_op.type,
                        dist_op.serial_op.desc.id(),
                        dist_op.serial_op.desc.original_id(),
                        dist_op.dist_attr,
                    )
        return True

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
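            # Members holding program/graph/node handles are shared by reference
            # instead of being deep-copied.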
            if k in [
                "_original_serial_main_program",
                "_original_serial_startup_program",
                "_serial_main_program",
                "_serial_startup_program",
                "_serial_graph",
                "_dist_main_programs",
                "_dist_startup_programs",
                "_serial_ordered_nodes",
                "_serial_ordered_tensor_nodes",
                "_serial_ordered_op_nodes",
                "_original_serial_loss",
                "_original_serial_feed_vars",
                "_original_serial_fetch_vars",
                "_serial_loss",
                "_serial_feed_vars",
                "_serial_fetch_vars",
                "_serial_optimizer",
                "_backup_serial_main_program_stack",
                "_backup_serial_startup_program_stack",
                "_pass_context",
                "_tensor_nodes_with_same_name",
            ]:
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))

        # update dist tensor's dist_context
        for key in result._dist_tensors_for_program.keys():
            result._dist_tensors_for_program[key]._dist_context = result
        return result


class DistributedOperatorContext:
    """
    DistributedOperatorContext is used to create a dist op desc in a Program.
    Every time a new dist op is created, the context should be updated accordingly.
    """

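    # Rough flow when lowering one serial op (an illustrative sketch; `dist_ctx`,
    # `rank`, `varname_mapping` and `serial_op` are hypothetical objects owned by
    # the calling pass):
    #
    #   op_ctx = dist_ctx.dist_op_context
    #   op_ctx.varname_mapping = varname_mapping
    #   op_ctx.rank_id = rank
    #   kinputs, koutputs = op_ctx.prepare_context(serial_op)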
    def __init__(self):
        self._dst_main_program = None
        self._main_block = None
        self._dst_startup_program = None
        self._startup_block = None
        self._cur_src_op = None
        self._cur_dist_attr = None
        self.grad_op_id_to_op_id = {}
        self.grad_var_to_var = defaultdict(dict)
        self._work_block = None
        self.already_init_sync_vars = set()
        self.varname_mapping = None
        self.rank_id = None
        # NOTE: Support correct parallelism for high-order differential models.
        # By default, _exceed_backward_init_op is False, meaning we are in the forward phase;
        # once it becomes True, we are in the backward phase.
        # The final solution should be to revise the high-order differential logic for these two phases in the future.
        self._exceed_backward_init_op = False

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in [
                "_dst_main_program",
                "_dst_startup_program",
                "_cur_src_op",
                "_work_block",
                "_main_block",
                "_startup_block",
            ]:
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

    @property
    def dst_main_program(self):
        return self._dst_main_program

    @dst_main_program.setter
    def dst_main_program(self, prog):
        self._dst_main_program = prog
        self._main_block = prog.blocks[0]

    @property
    def main_block(self):
        return self._main_block

    @property
    def dst_startup_program(self):
        return self._dst_startup_program

    @dst_startup_program.setter
    def dst_startup_program(self, prog):
        self._dst_startup_program = prog
        self._startup_block = prog.blocks[0]

    @property
    def startup_block(self):
        return self._startup_block

    @property
    def work_block(self):
        assert self._work_block is not None
        return self._work_block

    @work_block.setter
    def work_block(self, block):
        assert block is not None
        self._work_block = block

    @property
    def cur_src_op(self):
        assert self._cur_src_op is not None
        return self._cur_src_op

    def in_backward_phase(self):
        return self._exceed_backward_init_op

    def prepare_context(self, src_op):

        self._cur_src_op = src_op

        if is_loss_grad_op(src_op):
            self._exceed_backward_init_op = True

        # build input varname mapping
        kinputs = {}
        for input_name in src_op.desc.input_names():
            varnames = []
            for varname in src_op.desc.input(input_name):
1154 1155
                assert varname in self.varname_mapping
                varnames.append(self.varname_mapping[varname])
            kinputs[input_name] = varnames

        # build output varname mapping
        koutputs = {}
        for output_name in src_op.desc.output_names():
            varnames = []
            for varname in src_op.desc.output(output_name):
                assert varname in self.varname_mapping
                varnames.append(self.varname_mapping[varname])
            koutputs[output_name] = varnames

        return kinputs, koutputs


class BlockState:
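    # BlockState records how the blocks of a program are split between the forward
    # and backward phases, mapping each backward block back to the forward block
    # it was derived from.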
    def __init__(self):
        self.nblock = 0
        self.forward_indices = []
        self.backward_indices = []
        self.backward_to_forward_index_map = {}

    def parse_forward_blocks(self, program):

        while program.current_block_idx != 0:
            program._rollback()

        assert program.current_block_idx == 0

        for idx, block in enumerate(program.blocks):

            assert idx == block.idx, "index doesn't match"
            assert (
                block.forward_block_idx == -1
            ), "forward_block_idx of forward block [{}] is not [{}]".format(
                idx, block.forward_block_idx
            )
            self.forward_indices.append(idx)
            self.nblock += 1

        assert self.nblock >= 1

    def parse_backward_blocks(self, program):

        assert 0 in self.forward_indices, "forward block indices are {}".format(
            self.forward_indices
        )
        self.backward_to_forward_index_map[0] = 0

        for idx, block in enumerate(program.blocks):

            if idx < len(self.forward_indices):
                continue

            assert idx == block.idx, "index doesn't match"
            assert block.forward_block_idx in self.forward_indices
            self.backward_indices.append(idx)
            self.backward_to_forward_index_map[idx] = block.forward_block_idx
            self.nblock += 1

        assert self.nblock == len(program.blocks)