dist_context.py 45.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy
from collections import defaultdict
from paddle.fluid import framework
18
from paddle.fluid.framework import set_flags
19
from paddle.fluid import core
20
from paddle.distributed.passes import PassContext
21 22 23
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
24
from .utils import _copy_dist_attr_to_cpp
Z
zhaoyingli 已提交
25
from .utils import is_loss_grad_op, __no_shape_var_type__
26

27

28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44
# A default context always exists for the user, who may replace it with another one.
_g_default_distributed_context = None


def get_default_distributed_context():
    """Return the module-wide default distributed context.

    When no default has been set yet, a fresh ``DistributedContext`` is
    created, installed through ``set_default_distributed_context`` and then
    returned.
    """
    global _g_default_distributed_context
    if _g_default_distributed_context is not None:
        return _g_default_distributed_context
    # Lazily create the default on first use.
    set_default_distributed_context(DistributedContext())
    return _g_default_distributed_context


def set_default_distributed_context(dist_context):
    """Replace the module-wide default distributed context with `dist_context`."""
    global _g_default_distributed_context
    _g_default_distributed_context = dist_context


45 46 47 48
def _node_id(node):
    """Return the ``(graph_id, id)`` pair identifying a graph node.

    Both values are read from the object stored in ``node.node``.
    """
    graph_id = node.node.graph_id()
    local_id = node.node.id()
    return (graph_id, local_id)


49 50 51 52 53 54
class DistributedContext:
    """
    DistributedContext collects the distributed information related to a program and its graph.
    Each auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
    """

55 56 57 58 59 60 61 62 63 64 65
    def __init__(
        self,
        serial_main_prog=None,
        serial_startup_prog=None,
        serial_optimizer=None,
        serial_loss=None,
        feed_vars={},
        fetch_vars={},
        cluster=None,
        strategy=None,
    ):
66 67 68
        # Data members related to original programs (unchanged)
        self._original_serial_main_program = serial_main_prog
        self._original_serial_startup_program = serial_startup_prog
69
        self._original_serial_optimizer = serial_optimizer
70
        self._original_serial_loss = serial_loss
71 72
        self._original_serial_feed_vars = feed_vars
        self._original_serial_fetch_vars = fetch_vars
73 74 75 76

        # Data members related to programs (changed)
        self._serial_main_program = None
        self._serial_startup_program = None
77 78 79 80
        self._serial_loss = None
        self._serial_optimizer = None
        self._serial_feed_vars = {}
        self._serial_fetch_vars = {}
81
        self._lr_optimizer = None  # record the optimzier holding lr_scheduler
82 83

        # Data members related to the program
84 85
        self._dist_tensors_for_program = {}
        self._dist_ops_for_program = {}
86 87

        # Data members related to the graph
88
        self._serial_graph = None
89 90
        self._dist_tensors_for_graph = {}
        self._dist_ops_for_graph = {}
91 92
        self._node_id_to_tensor_id = {}
        self._node_id_to_op_id = {}
93

94
        # Data members related to the distributed programs
95
        # Distributed programs
96 97
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
98 99
        self._dist_op_context = DistributedOperatorContext()
        self._process_meshes = []
100

101
        self._cluster = cluster
102 103 104 105
        self._strategy = strategy

        # Pass Context
        self._pass_context = PassContext()
106
        self._block_state = BlockState()
107 108 109 110 111 112 113 114

        # Other data members
        self._serial_ordered_tensor_nodes = []
        self._serial_ordered_op_nodes = []
        self._serial_ordered_nodes = []
        # self._tensor_id_to_tensor_node_ids = {}

        self._is_initialized = False
115
        # TODO: need a better way to remove the following flag
116 117 118 119 120 121 122
        self._need_copy_dist_attr_to_graph = False
        self._backup_pass_context_stack = []
        self._backup_block_state_stack = []
        self._backup_dist_tensors_for_program_stack = []
        self._backup_dist_ops_for_program_stack = []
        self._backup_serial_main_program_stack = []
        self._backup_serial_startup_program_stack = []
123

124 125 126
        # flag whether scale gradient with dp size
        self._gradient_scale = True

127 128 129
        # A flag indicates whether the used parallelism is data parallel
        self._data_parallel = False

130
    @property
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145
    def serial_main_program(self):
        """Return the working serial main program, or None if it has not been set yet."""
        return self._serial_main_program

    @property
    def serial_startup_program(self):
        """Return the working serial startup program, or None if it has not been set yet."""
        return self._serial_startup_program

    @property
    def serial_loss(self):
        """Return the loss resolved against the working serial main program (None until set)."""
        return self._serial_loss

    @property
    def serial_optimizer(self):
        """Return the working serial optimizer (None until set)."""
        return self._serial_optimizer

146 147 148 149 150 151 152
    @property
    def serial_feed_vars(self):
        """Return the dict of feed variables resolved against the working serial main program."""
        return self._serial_feed_vars

    @property
    def serial_fetch_vars(self):
        """Return the dict of fetch variables resolved against the working serial main program."""
        return self._serial_fetch_vars
153

154 155 156 157 158 159 160 161 162 163 164 165
    @property
    def dist_main_programs(self):
        """Return the dict of distributed main programs."""
        return self._dist_main_programs

    @property
    def dist_startup_programs(self):
        """Return the dict of distributed startup programs."""
        return self._dist_startup_programs

    @property
    def cluster(self):
        """Return the cluster object given to the constructor."""
        return self._cluster

166 167 168 169
    @property
    def strategy(self):
        """Return the strategy object given to the constructor."""
        return self._strategy

170 171 172 173
    @property
    def serial_graph(self):
        """Return the IR graph built from the serial main program, or None if none was built."""
        return self._serial_graph

174 175 176 177
    @property
    def serial_ordered_nodes(self):
        """Return the list of serial graph nodes (tensor and op nodes) in their computed order."""
        return self._serial_ordered_nodes

178 179 180 181
    @property
    def process_meshes(self):
        """Return the list of process meshes registered in this context."""
        return self._process_meshes

182 183 184 185
    @property
    def pass_context(self):
        """Return the pass context currently held by this distributed context."""
        return self._pass_context

186 187 188 189
    @property
    def dist_op_context(self):
        """Return the distributed operator context currently held by this distributed context."""
        return self._dist_op_context

190 191 192 193
    @property
    def block_state(self):
        """Return the block state currently held by this distributed context."""
        return self._block_state

194
    @property
195
    def has_annotation(self):
196
        return len(self._dist_tensors_for_program) or len(
197 198
            self._dist_ops_for_program
        )
199

200 201 202 203 204 205 206 207
    @property
    def gradient_scale(self):
        """Return the flag telling whether gradients are scaled with the dp size."""
        return self._gradient_scale

    @gradient_scale.setter
    def gradient_scale(self, gs):
        """Set the flag telling whether gradients are scaled with the dp size."""
        self._gradient_scale = gs

208 209 210 211 212 213 214 215
    @property
    def data_parallel(self):
        """Return the flag indicating whether the used parallelism is data parallel."""
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, dp):
        """Set the flag indicating whether the used parallelism is data parallel."""
        self._data_parallel = dp

216 217
    def _backup_serial_info(self, mode):
        self._backup_serial_main_program_stack.append(
218 219
            self._serial_main_program.clone()
        )
220
        self._backup_serial_startup_program_stack.append(
221 222 223 224 225
            self._serial_startup_program.clone()
        )
        self._backup_pass_context_stack.append(
            copy.deepcopy(self._pass_context)
        )
226 227 228 229
        self._backup_block_state_stack.append(copy.deepcopy(self._block_state))

    def _backup_dist_info(self, mode):
        self._backup_dist_tensors_for_program_stack.append(
230 231
            copy.deepcopy(self._dist_tensors_for_program)
        )
232
        self._backup_dist_ops_for_program_stack.append(
233 234
            copy.deepcopy(self._dist_ops_for_program)
        )
235 236 237 238 239 240 241 242

    def _backup(self, serial=True, serial_mode=None, dist=True, dist_mode=None):
        # Use this function carefully
        if serial:
            self._backup_serial_info(serial_mode)
        if dist:
            self._backup_dist_info(dist_mode)

243
    def _restore_serial_loss(self):
244 245
        if self._original_serial_loss:
            if isinstance(self._original_serial_loss, list):
246 247 248 249 250
                if len(self._original_serial_loss) == 1:
                    loss = self._original_serial_loss[0]
                    block_idx = loss.block.idx
                    var_name = loss.name
                    var = self._serial_main_program.blocks[
251 252
                        block_idx
                    ]._var_recursive(var_name)
253 254 255 256 257
                    self._serial_loss = var
                elif len(self._original_serial_loss) == 0:
                    self._serial_loss = []
                else:
                    raise ValueError("multi loss vars are not supported.")
258
            else:
259 260 261
                block_idx = self._original_serial_loss.block.idx
                var_name = self._original_serial_loss.name
                var = self._serial_main_program.blocks[
262 263
                    block_idx
                ]._var_recursive(var_name)
264 265
                self._serial_loss = var

266
    def _restore_serial_feed_vars(self):
267 268 269 270 271 272
        for key, var_list in self._original_serial_feed_vars.items():
            new_var_list = []
            for var in var_list:
                block_idx = var.block.idx
                var_name = var.name
                var = self._serial_main_program.blocks[
273 274
                    block_idx
                ]._var_recursive(var_name)
275 276 277
                new_var_list.append(var)
            self._serial_feed_vars[key] = new_var_list

278
    def _restore_serial_fetch_vars(self):
279 280
        for key, var_list in self._original_serial_fetch_vars.items():
            new_var_list = []
281 282 283 284 285 286 287 288
            # metrics is a list of list
            if key == "metrics":
                for inner_var_list in var_list:
                    new_inner_var_list = []
                    for var in inner_var_list:
                        block_idx = var.block.idx
                        var_name = var.name
                        var = self._serial_main_program.blocks[
289 290
                            block_idx
                        ]._var_recursive(var_name)
291 292 293 294 295 296 297
                        new_inner_var_list.append(var)
                    new_var_list.append(new_inner_var_list)
            else:
                for var in var_list:
                    block_idx = var.block.idx
                    var_name = var.name
                    var = self._serial_main_program.blocks[
298 299
                        block_idx
                    ]._var_recursive(var_name)
300
                    new_var_list.append(var)
301 302
            self._serial_fetch_vars[key] = new_var_list

303 304
    def _restore_serial_info(self, mode="to_backup"):
        if mode == "to_backup":
305 306
            self._serial_main_program = (
                self._backup_serial_main_program_stack.pop()
307
            )
308 309
            self._serial_startup_program = (
                self._backup_serial_startup_program_stack.pop()
310 311 312 313
            )
        elif mode == "to_original":
            assert self._original_serial_main_program is not None
            assert self._original_serial_startup_program is not None
314 315
            self._serial_main_program = (
                self._original_serial_main_program.clone()
316
            )
317 318
            self._serial_startup_program = (
                self._original_serial_startup_program.clone()
319 320 321 322 323 324
            )

        self._restore_serial_loss()
        self._restore_serial_feed_vars()
        self._restore_serial_fetch_vars()
        self._serial_optimizer = self._original_serial_optimizer
325 326 327 328 329
        self._pass_context = self._backup_pass_context_stack.pop()
        self._block_state = self._backup_block_state_stack.pop()

    def _restore_dist_info(self, mode="to_backup"):
        """Restore the program-level distributed state of this context.

        Args:
            mode: selects how the dist tensor/op tables are restored:

                * ``"to_backup"``: pop both tables from their backup stacks.
                * ``"to_original"``: deep-copy the snapshots taken at
                  initialization (both must be non-empty).
                * ``"to_default"``: reset the dist attr of every entry whose
                  id was recorded at initialization and drop all other
                  entries.
                * anything else: empty both tables in place.

        In every mode the distributed programs, the distributed operator
        context and the process meshes are reset, and the graph is flagged as
        needing a fresh copy of the dist attrs.
        """
        if mode == "to_backup":
            self._dist_tensors_for_program = (
                self._backup_dist_tensors_for_program_stack.pop()
            )
            self._dist_ops_for_program = (
                self._backup_dist_ops_for_program_stack.pop()
            )
        elif mode == "to_original":
            assert self._original_dist_tensors_for_program
            assert self._original_dist_ops_for_program
            self._dist_tensors_for_program = copy.deepcopy(
                self._original_dist_tensors_for_program
            )
            self._dist_ops_for_program = copy.deepcopy(
                self._original_dist_ops_for_program
            )
        elif mode == "to_default":
            # Sets give O(1) membership; the recorded ids are stored as lists.
            initial_tensor_ids = set(self._tensors_ids)
            new_tensors_ids = []
            for tensor_id, dist_tensor in self._dist_tensors_for_program.items():
                if tensor_id in initial_tensor_ids:
                    dist_tensor.dist_attr.reset()
                else:
                    new_tensors_ids.append(tensor_id)
            # Removal is deferred so the dict is not mutated while iterated.
            for tensor_id in new_tensors_ids:
                self._dist_tensors_for_program.pop(tensor_id)

            initial_op_ids = set(self._ops_ids)
            new_ops_ids = []
            for op_id, dist_op in self._dist_ops_for_program.items():
                if op_id in initial_op_ids:
                    dist_op.dist_attr.reset()
                else:
                    new_ops_ids.append(op_id)
            for op_id in new_ops_ids:
                self._dist_ops_for_program.pop(op_id)
        else:
            # Empty both tables in place (the dict objects are kept).
            self._dist_tensors_for_program.clear()
            self._dist_ops_for_program.clear()
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._dist_op_context = DistributedOperatorContext()
        self._need_copy_dist_attr_to_graph = True
        self._process_meshes = []

385 386 387 388 389 390 391
    def _restore(
        self,
        serial=True,
        serial_mode="to_backup",
        dist=True,
        dist_mode="to_backup",
    ):
392 393 394 395 396 397
        # Use this function carefully
        if serial:
            self._restore_serial_info(serial_mode)
        if dist:
            self._restore_dist_info(dist_mode)

398
    def initialize(self, with_graph=True, with_cpp=False):
399 400
        if not self._is_initialized:
            if not self._serial_main_program:
401
                if self._original_serial_main_program:
402 403
                    self._serial_main_program = (
                        self._original_serial_main_program.clone()
404
                    )
405
            if not self._serial_startup_program:
406
                if self._original_serial_startup_program:
407 408
                    self._serial_startup_program = (
                        self._original_serial_startup_program.clone()
409
                    )
410
            if not self._serial_loss:
411
                self._restore_serial_loss()
412 413 414
            if not self._serial_optimizer:
                self._serial_optimizer = self._original_serial_optimizer
            if not self._serial_feed_vars:
415
                self._restore_serial_feed_vars()
416
            if not self._serial_fetch_vars:
417
                self._restore_serial_fetch_vars()
418

419
            self._init_dist_attr_for_program()
420 421
            # Backup the original distributed information for later restore
            self._original_dist_tensors_for_program = copy.deepcopy(
422 423
                self._dist_tensors_for_program
            )
424
            self._original_dist_ops_for_program = copy.deepcopy(
425 426
                self._dist_ops_for_program
            )
427 428 429
            self._tensors_ids = list(self._dist_tensors_for_program.keys())
            self._ops_ids = list(self._dist_ops_for_program.keys())
            self._is_initialized = True
430

431 432 433 434
            # TODO: This will be removed in the future
            if with_cpp:
                _copy_dist_attr_to_cpp(self)

435 436 437
            if with_graph:
                set_flags({"FLAGS_convert_all_blocks": True})
                self._serial_graph = framework.IrGraph(
438 439
                    core.Graph(self._serial_main_program.desc)
                )
440 441 442 443
                self._init_dist_attr_for_graph()
                self._need_copy_dist_attr_to_graph = False

        if self._need_copy_dist_attr_to_graph and with_graph:
444
            self.copy_dist_attr_from_program_to_graph()
445

446
    def add_process_mesh(self, process_mesh):
        """Register `process_mesh` in this context unless it is already present.

        Args:
            process_mesh: the ``ProcessMesh`` to register.

        Raises:
            AssertionError: if `process_mesh` is not a ``ProcessMesh``.
        """
        # The message used to name ``dim_mapping``; the checked argument is
        # ``process_mesh``.
        assert isinstance(
            process_mesh, ProcessMesh
        ), 'The type of process_mesh must be ProcessMesh.'
        if process_mesh not in self.process_meshes:
            self._process_meshes.append(process_mesh)

    def add_dist_tensor_for_program(self, dist_tensor):
        inner_serial_tensor = dist_tensor.serial_tensor
455
        inner_serial_tensor_id = inner_serial_tensor.desc.original_id()
456 457 458 459
        self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor

    def add_dist_op_for_program(self, dist_op):
        inner_serial_op = dist_op.serial_op
460
        inner_serial_op_id = inner_serial_op.desc.original_id()
461 462 463 464
        self._dist_ops_for_program[inner_serial_op_id] = dist_op

    def get_dist_tensor_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
465 466 467 468 469
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor
        else:
            serial_tensor_id = serial_tensor.desc.original_id()
470
            dist_tensor = self._dist_tensors_for_program.get(
471 472
                serial_tensor_id, None
            )
473 474 475 476
            if dist_tensor:
                return dist_tensor
            else:
                return None
477 478

    def get_dist_tensor_for_graph(self, serial_tensor_node):
        """Return the dist tensor registered for the graph node, or None if absent."""
        return self._dist_tensors_for_graph.get(
            _node_id(serial_tensor_node), None
        )

482 483 484 485 486 487 488 489 490 491 492 493
    def get_dist_op_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op
        else:
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op
            else:
                return None
494

495 496 497 498 499
    def del_dist_op_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        if self._dist_ops_for_program.get(serial_tensor_id, None):
            del self._dist_ops_for_program[serial_tensor_id]

500
    def get_dist_op_for_graph(self, serial_op_node):
        """Return the dist op registered for the graph node, or None if absent."""
        return self._dist_ops_for_graph.get(_node_id(serial_op_node), None)
503 504 505 506 507 508 509

    def get_tensor_dist_attr_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
510
            serial_tensor_id = serial_tensor.desc.original_id()
511
            dist_tensor = self._dist_tensors_for_program.get(
512 513
                serial_tensor_id, None
            )
514 515 516 517
            if dist_tensor:
                return dist_tensor.dist_attr
            else:
                return None
518

519 520 521 522 523 524 525
    def get_tensor_dist_attr_for_program_with_id(self, tensor_id):
        dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            return None

526 527 528 529 530
    def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
        """Wrap `serial_tensor` and `dist_attr` in a new dist tensor and register it."""
        self.add_dist_tensor_for_program(
            DistributedTensor(serial_tensor, dist_attr)
        )

    def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
        """Return the dist attr of the dist tensor registered for the graph node, or None."""
        found = self._dist_tensors_for_graph.get(
            _node_id(serial_tensor_node), None
        )
        return found.dist_attr if found else None

    def get_op_dist_attr_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
546 547 548 549 550 551
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op.dist_attr
            else:
                return None
552

553 554 555 556 557 558 559
    def get_op_dist_attr_for_program_with_id(self, op_id):
        dist_op = self._dist_ops_for_program.get(op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            return None

560 561 562 563 564
    def set_op_dist_attr_for_program(self, serial_op, dist_attr):
        """Wrap `serial_op` and `dist_attr` in a new dist op and register it."""
        self.add_dist_op_for_program(DistributedOperator(serial_op, dist_attr))

    def get_op_dist_attr_for_graph(self, serial_op_node):
        """Return the dist attr of the dist op registered for the graph node, or None."""
        found = self._dist_ops_for_graph.get(_node_id(serial_op_node), None)
        return found.dist_attr if found else None

572 573
    def get_dist_attr_for_graph(self, serial_node):
        """Return the dist attr registered for a tensor or op graph node.

        Tensor nodes are looked up in the graph-level dist tensor table and op
        nodes in the graph-level dist op table. None is returned for any other
        node and when nothing truthy is registered.
        """
        if serial_node.is_var() and serial_node.var() is not None:
            table = self._dist_tensors_for_graph
        elif serial_node.is_op() and serial_node.op() is not None:
            table = self._dist_ops_for_graph
        else:
            return None
        found = table.get(_node_id(serial_node), None)
        return found.dist_attr if found else None
590

591
    def _init_dist_attr_for_program(self, no_default=False):
592
        # Copy the dist tensors and dist ops annotated by users from the default context
593 594 595 596 597
        if not no_default:
            default_ctx = get_default_distributed_context()
            self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
        else:
            default_ctx = self
598 599
        # Copy the data parallel flag from the default context
        self._data_parallel = default_ctx.data_parallel
600
        for block in self._serial_main_program.blocks:
601 602 603
            for tensor in block.vars.values():
                # Copy the distributed tensors in the default context
                default_dist_tensor = default_ctx.get_dist_tensor_for_program(
604 605
                    tensor
                )
606
                if default_dist_tensor and default_ctx is not self:
607 608 609 610 611
                    dist_tensor = DistributedTensor(tensor)
                    dist_tensor.dist_attr = copy.deepcopy(
                        default_dist_tensor.dist_attr
                    )
                    self.add_dist_tensor_for_program(dist_tensor)
612 613 614 615 616 617 618 619
                current_dist_tensor = self.get_dist_tensor_for_program(tensor)
                if current_dist_tensor is None:
                    dist_tensor = DistributedTensor(tensor)
                    self.add_dist_tensor_for_program(dist_tensor)
            for op in block.ops:
                # Copy the distributed operators in the default context
                default_dist_op = default_ctx.get_dist_op_for_program(op)
                if default_dist_op and default_ctx is not self:
620 621 622
                    dist_op = DistributedOperator(op)
                    dist_op.dist_attr = copy.deepcopy(default_dist_op.dist_attr)
                    self.add_dist_op_for_program(dist_op)
623 624 625 626
                current_dist_op = self.get_dist_op_for_program(op)
                if current_dist_op is None:
                    dist_op = DistributedOperator(op)
                    self.add_dist_op_for_program(dist_op)
627
        self._original_dist_tensors_for_program = copy.deepcopy(
628 629
            self._dist_tensors_for_program
        )
630
        self._original_dist_ops_for_program = copy.deepcopy(
631 632
            self._dist_ops_for_program
        )
633

634
    def _order_nodes_by_program_order(self):
635 636
        def _contains(nodes, target_node):
            for node in nodes:
637
                if _node_id(node) == _node_id(target_node):
638 639 640
                    return True
            return False

641 642 643 644 645 646
        serial_ordered_tensor_nodes = []
        serial_ordered_op_nodes = []
        all_nodes = []
        for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
            for node in graph.all_nodes():
                all_nodes.append(node)
647 648
        for node in all_nodes:
            if node.is_var() and node.var() is not None:
649
                serial_ordered_tensor_nodes.append(node)
650
            if node.is_op() and node.op() is not None:
651 652
                serial_ordered_op_nodes.append(node)
        serial_ordered_tensor_nodes.sort(
653 654
            key=lambda node: node.node.original_desc_id()
        )
655
        serial_ordered_op_nodes.sort(
656 657
            key=lambda node: node.node.original_desc_id()
        )
658
        num_nodes_before = len(serial_ordered_tensor_nodes) + len(
659 660
            serial_ordered_op_nodes
        )
661 662 663

        new_serial_ordered_tensor_nodes = []
        new_serial_ordered_op_nodes = []
664
        new_serial_ordered_nodes = []
665
        for op_node in serial_ordered_op_nodes:
666 667
            tensor_nodes = []
            for tensor_node in op_node.inputs:
668 669 670 671 672
                if (
                    tensor_node.is_var()
                    and tensor_node.var() is not None
                    and not _contains(new_serial_ordered_nodes, tensor_node)
                ):
673
                    tensor_nodes.append(tensor_node)
674
                    new_serial_ordered_tensor_nodes.append(tensor_node)
675
            tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
676 677
            new_serial_ordered_nodes.extend(tensor_nodes)
            new_serial_ordered_nodes.append(op_node)
678
            new_serial_ordered_op_nodes.append(op_node)
679 680
            tensor_nodes = []
            for tensor_node in op_node.outputs:
681 682 683 684 685
                if (
                    tensor_node.is_var()
                    and tensor_node.var() is not None
                    and not _contains(new_serial_ordered_nodes, tensor_node)
                ):
686
                    tensor_nodes.append(tensor_node)
687 688
                    new_serial_ordered_tensor_nodes.append(tensor_node)
            tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
689
            new_serial_ordered_nodes.extend(tensor_nodes)
690
        new_serial_ordered_tensor_nodes.sort(
691 692
            key=lambda node: node.node.original_desc_id()
        )
693
        new_serial_ordered_op_nodes.sort(
694 695
            key=lambda node: node.node.original_desc_id()
        )
696 697
        self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
        self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
698
        self._serial_ordered_nodes = new_serial_ordered_nodes
699
        assert len(self._serial_ordered_nodes) == len(
700 701
            self._serial_ordered_tensor_nodes
        ) + len(self._serial_ordered_op_nodes)
702 703 704 705 706 707 708 709
        self._serial_orphan_tensor_nodes = []
        for tensor_node in serial_ordered_tensor_nodes:
            if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
                self._serial_orphan_tensor_nodes.append(tensor_node)
        if len(self._serial_ordered_nodes) != num_nodes_before:
            print(
                "WARNING: there are some orphan tensors or ops which are not used in the execution."
            )
710

711 712 713
    def _init_dist_attr_for_graph(self):
        """Populate the graph-side distributed tensors and operators.

        After ordering the graph nodes by program order, every var/op node of
        ``self.serial_ordered_nodes`` is paired with the program-side
        distributed tensor/operator whose key, or whose serial desc
        ``original_id()``, equals the node's ``original_desc_id()``. The
        pairing is recorded in ``self._node_id_to_tensor_id`` /
        ``self._node_id_to_op_id`` and a new ``DistributedTensor`` /
        ``DistributedOperator`` built from the matched serial object and its
        ``dist_attr`` is stored under the node id.

        Raises:
            AssertionError: if a node has no program-side counterpart.
        """
        self._order_nodes_by_program_order()
        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                wanted_id = node.node.original_desc_id()
                matched_tensor = None
                for (
                    program_id,
                    candidate,
                ) in self._dist_tensors_for_program.items():
                    if wanted_id != program_id:
                        desc = candidate.serial_tensor.desc
                        if wanted_id != desc.original_id():
                            continue
                    # The scan does not stop at the first hit, so the last
                    # matching entry is the one that is kept.
                    matched_tensor = candidate
                    self._node_id_to_tensor_id[_node_id(node)] = program_id
                assert (
                    matched_tensor is not None
                ), "Tensor must have a distributed tensor after the initialization for program."
                self._dist_tensors_for_graph[
                    _node_id(node)
                ] = DistributedTensor(
                    matched_tensor.serial_tensor, matched_tensor.dist_attr
                )
            if node.is_op() and node.op() is not None:
                wanted_id = node.node.original_desc_id()
                matched_op = None
                for (
                    program_id,
                    candidate,
                ) in self._dist_ops_for_program.items():
                    if wanted_id != program_id:
                        desc = candidate.serial_op.desc
                        if wanted_id != desc.original_id():
                            continue
                    # Same rule as for tensors: the last matching entry wins.
                    matched_op = candidate
                    self._node_id_to_op_id[_node_id(node)] = program_id
                assert (
                    matched_op is not None
                ), "Operator must have a distributed operator after the initialization for program."
                self._dist_ops_for_graph[_node_id(node)] = DistributedOperator(
                    matched_op.serial_op, matched_op.dist_attr
                )

    def clear_dist_info_for_program(self):
        """Empty, in place, the program-side distributed tensor and operator tables."""
        for table in (
            self._dist_tensors_for_program,
            self._dist_ops_for_program,
        ):
            table.clear()

    def clear_dist_info_for_graph(self):
        """Empty, in place, the graph-side distributed tensor and operator tables."""
        for table in (
            self._dist_tensors_for_graph,
            self._dist_ops_for_graph,
        ):
            table.clear()

    def copy_dist_attr_from_program_to_graph(self):
        """Rebuild the graph-side distributed tensors/operators from the program.

        Every var/op node of ``self.serial_ordered_nodes`` is matched against
        the program-side distributed tensors/operators: an entry matches when
        its key, or the ``original_id()`` of its serial desc, equals the node's
        ``original_desc_id()``. A new ``DistributedTensor`` /
        ``DistributedOperator`` built from the matched serial object and its
        ``dist_attr`` is then stored under the node id.

        Raises:
            AssertionError: if a node has no program-side counterpart.
        """
        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                wanted_id = node.node.original_desc_id()
                matched_tensor = None
                for (
                    program_id,
                    candidate,
                ) in self._dist_tensors_for_program.items():
                    if wanted_id != program_id:
                        desc = candidate.serial_tensor.desc
                        if wanted_id != desc.original_id():
                            continue
                    # No early exit: the last matching entry is kept.
                    matched_tensor = candidate
                assert (
                    matched_tensor is not None
                ), "Tensor must have a distributed tensor after the initialization for program."
                self._dist_tensors_for_graph[
                    _node_id(node)
                ] = DistributedTensor(
                    matched_tensor.serial_tensor, matched_tensor.dist_attr
                )
            if node.is_op() and node.op() is not None:
                wanted_id = node.node.original_desc_id()
                matched_op = None
                for (
                    program_id,
                    candidate,
                ) in self._dist_ops_for_program.items():
                    if wanted_id != program_id:
                        desc = candidate.serial_op.desc
                        if wanted_id != desc.original_id():
                            continue
                    # No early exit: the last matching entry is kept.
                    matched_op = candidate
                assert (
                    matched_op is not None
                ), "Operator must have a distributed operator after the initialization for program."
                self._dist_ops_for_graph[_node_id(node)] = DistributedOperator(
                    matched_op.serial_op, matched_op.dist_attr
                )

    def copy_dist_attr_from_graph_to_program(self):
        """Write the graph-side distributed attributes back onto the program.

        For each var node of ``self._serial_ordered_nodes`` the graph-side
        tensor dist attr is assigned to the program-side distributed tensor
        recorded for that node in ``self._node_id_to_tensor_id``; when several
        var nodes refer to the same tensor id only the first one is used. Each
        op node's graph-side dist attr is assigned likewise through
        ``self._node_id_to_op_id``. Finally, the program-side distributed
        tensor of every orphan tensor node gets ``self._process_meshes[0]`` as
        its process mesh.

        Raises:
            AssertionError: if the context is not initialized, or if an orphan
                tensor node is not registered in the program under either its
                ``id()`` or its ``original_id()``.
        """
        assert (
            self._is_initialized
        ), "Both program and graph must be initialized."
        updated_tensor_ids = set()
        for node in self._serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                tensor_id = self._node_id_to_tensor_id[_node_id(node)]
                # If a var has multiple var nodes in graph, only use the first one for now
                if tensor_id not in updated_tensor_ids:
                    tensor_dist_attr_for_graph = (
                        self.get_tensor_dist_attr_for_graph(node)
                    )
                    dist_tensor_for_program = self._dist_tensors_for_program[
                        tensor_id
                    ]
                    dist_tensor_for_program.dist_attr = (
                        tensor_dist_attr_for_graph
                    )
                    updated_tensor_ids.add(tensor_id)
            if node.is_op() and node.op() is not None:
                op_id = self._node_id_to_op_id[_node_id(node)]
                op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
                dist_op_for_program = self._dist_ops_for_program[op_id]
                dist_op_for_program.dist_attr = op_dist_attr_for_graph

        # TODO: the completion algorithm skips orphan tensors,
        # here we just set their process_mesh to the first one.
        for orphan_node in self._serial_orphan_tensor_nodes:
            orphan_var = orphan_node.var()
            dist_tensor = self._dist_tensors_for_program.get(orphan_var.id())
            if not dist_tensor:
                # Not registered under id(): retry with original_id().
                dist_tensor = self._dist_tensors_for_program.get(
                    orphan_var.original_id()
                )
            # Fail with a readable message instead of an AttributeError on
            # None when neither lookup finds the tensor.
            assert dist_tensor is not None, (
                "Orphan tensor (id: {}, original_id: {}) does not have a "
                "distributed tensor in the program."
            ).format(orphan_var.id(), orphan_var.original_id())
            dist_tensor.dist_attr.process_mesh = self._process_meshes[0]

    def amend_dist_attr_for_program(self):
        """Reset dims-mapping entries that cannot shard their dimension.

        For every program-side distributed tensor, and for every input and
        output of every program-side distributed operator, an entry of the
        dims mapping is set to -1 in place when either

        * the mapped process-mesh dimension is larger than the (positive)
          size of the tensor dimension, or
        * the process mesh holds exactly one process.

        Variables whose type is in ``__no_shape_var_type__``, and operator
        inputs whose serial variable is ``None``, are treated as having an
        empty shape, so their mappings are left untouched. Operators on a
        single-process mesh additionally get ``impl_type`` "default" and
        ``impl_idx`` 0.
        """

        def _unshard(dims_mapping, shape, mesh_shape, mesh_processes):
            # If the dimension of tensor is less than the sharding dimension of process mesh,
            # we just amend the dimension mapping to -1. (Is this really OK?)
            for axis, extent in enumerate(shape):
                mesh_axis = dims_mapping[axis]
                if mesh_axis == -1:
                    continue
                oversharded = extent > 0 and mesh_shape[mesh_axis] > extent
                if oversharded or len(mesh_processes) == 1:
                    dims_mapping[axis] = -1

        for dist_tensor in self._dist_tensors_for_program.values():
            serial_tensor = dist_tensor.serial_tensor
            attr = dist_tensor.dist_attr
            shape = (
                []
                if serial_tensor.type in __no_shape_var_type__
                else serial_tensor.shape
            )
            _unshard(
                attr.dims_mapping,
                shape,
                attr.process_mesh.topology,
                attr.process_mesh.processes,
            )

        for dist_op in self._dist_ops_for_program.values():
            serial_op = dist_op.serial_op
            attr = dist_op.dist_attr
            mesh_shape = attr.process_mesh.topology
            mesh_processes = attr.process_mesh.processes
            for arg_name in serial_op.input_arg_names:
                serial_input = dist_op.get_serial_input(arg_name)
                if (
                    serial_input is None
                    or serial_input.type in __no_shape_var_type__
                ):
                    shape = []
                else:
                    shape = serial_input.shape
                _unshard(
                    attr.get_input_dims_mapping(arg_name),
                    shape,
                    mesh_shape,
                    mesh_processes,
                )
            for arg_name in serial_op.output_arg_names:
                serial_output = dist_op.get_serial_output(arg_name)
                if serial_output.type in __no_shape_var_type__:
                    shape = []
                else:
                    shape = serial_output.shape
                _unshard(
                    attr.get_output_dims_mapping(arg_name),
                    shape,
                    mesh_shape,
                    mesh_processes,
                )
            if len(mesh_processes) == 1:
                dist_op.dist_attr.impl_type = "default"
                dist_op.dist_attr.impl_idx = 0

    def validate_dist_attr_for_program(self):
        """Check the distributed attributes of the serial main program.

        Every variable and every operator of every block of
        ``self.serial_main_program`` must have a program-side distributed
        tensor/operator, and that object's ``validate_dist_attr()`` must
        succeed.

        Returns:
            bool: ``True`` when no check fails.

        Raises:
            AssertionError: if the context is not initialized, if a tensor or
                operator has no distributed counterpart, or if a distributed
                attribute fails its own validation.
        """
        assert (
            self._is_initialized
        ), "Program must be initialized before validating its distributed attributes"
        for block in self.serial_main_program.blocks:
            for tensor in block.vars.values():
                dist_tensor = self.get_dist_tensor_for_program(tensor)
                # The message names the serial tensor itself: dist_tensor is
                # None on this failure path and must not be dereferenced.
                assert (
                    dist_tensor is not None
                ), "Tensor {} does not have a distributed attribute.".format(
                    tensor.name
                )
                if (dist_tensor is not None) and (
                    not dist_tensor.validate_dist_attr()
                ):
                    assert (
                        False
                    ), "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format(
                        dist_tensor.serial_tensor.name,
                        dist_tensor.serial_tensor.desc.id(),
                        dist_tensor.serial_tensor.desc.original_id(),
                        dist_tensor.dist_attr,
                    )
            for op in block.ops:
                dist_op = self.get_dist_op_for_program(op)
                # Same reasoning as for tensors: report the serial op's type.
                assert (
                    dist_op is not None
                ), "Operator {} does not have a distributed attribute.".format(
                    op.type
                )
                if (dist_op is not None) and (not dist_op.validate_dist_attr()):
                    assert (
                        False
                    ), "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format(
                        dist_op.serial_op.type,
                        dist_op.serial_op.desc.id(),
                        dist_op.serial_op.desc.original_id(),
                        dist_op.dist_attr,
                    )
        return True

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
991
            if k in [
992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012
                "_original_serial_main_program",
                "_original_serial_startup_program",
                "_serial_main_program",
                "_serial_startup_program",
                "_serial_graph",
                "_dist_main_programs",
                "_dist_startup_programs",
                "_serial_ordered_nodes",
                "_serial_ordered_tensor_nodes",
                "_serial_ordered_op_nodes",
                "_original_serial_loss",
                "_original_serial_feed_vars",
                "_original_serial_fetch_vars",
                "_serial_loss",
                "_serial_feed_vars",
                "_serial_fetch_vars",
                "_serial_optimizer",
                "_backup_serial_main_program_stack",
                "_backup_serial_startup_program_stack",
                "_pass_context",
            ]:
Z
zhaoyingli 已提交
1013 1014 1015
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
1016 1017 1018 1019

        # update dist tensor's dist_context
        for key in result._dist_tensors_for_program.keys():
            result._dist_tensors_for_program[key]._dist_context = result
Z
zhaoyingli 已提交
1020 1021
        return result


class DistributedOperatorContext:
    """
    Scratch state used while creating dist op descs in a Program.

    The context has to be refreshed (see ``prepare_context``) for each serial
    operator before its distributed counterpart is created.
    """

    def __init__(self):
        self._dst_main_program = None
        self._main_block = None
        self._dst_startup_program = None
        self._startup_block = None
        self._cur_src_op = None
        self._cur_dist_attr = None
        self.grad_op_id_to_op_id = {}
        self.grad_var_to_var = defaultdict(dict)
        self._work_block = None
        self.already_init_sync_vars = set()
        self.varname_mapping = None
        self.rank_id = None
        # NOTE Support correct parallelism for high-order differential model.
        # By default exceed_backward_init_op is False, meaning we are in the
        # Forward phase; once it is True we are in the Backward phase.
        # The final solution should be to revise the high-order differential
        # logic for these two phases in the future.
        self._exceed_backward_init_op = False

    def __deepcopy__(self, memo):
        """Copy the context, sharing the programs, blocks and current op.

        Args:
            memo (dict): the memo dictionary supplied by ``copy.deepcopy``.

        Returns:
            The new context; the attributes named in ``shared_by_reference``
            are assigned as-is, everything else is deep-copied.
        """
        shared_by_reference = frozenset(
            (
                "_dst_main_program",
                "_dst_startup_program",
                "_cur_src_op",
                "_work_block",
                "_main_block",
                "_startup_block",
            )
        )
        clone = self.__class__.__new__(self.__class__)
        memo[id(self)] = clone
        for name, value in self.__dict__.items():
            if name not in shared_by_reference:
                value = copy.deepcopy(value, memo)
            setattr(clone, name, value)
        return clone

    @property
    def dst_main_program(self):
        """The destination main program, or ``None`` if not set yet."""
        return self._dst_main_program

    @dst_main_program.setter
    def dst_main_program(self, prog):
        # The main block is always taken from the first block of the program.
        self._dst_main_program = prog
        self._main_block = prog.blocks[0]

    @property
    def main_block(self):
        """First block of the destination main program."""
        return self._main_block

    @property
    def dst_startup_program(self):
        """The destination startup program, or ``None`` if not set yet."""
        return self._dst_startup_program

    @dst_startup_program.setter
    def dst_startup_program(self, prog):
        # The startup block is always taken from the first block of the program.
        self._dst_startup_program = prog
        self._startup_block = prog.blocks[0]

    @property
    def startup_block(self):
        """First block of the destination startup program."""
        return self._startup_block

    @property
    def work_block(self):
        """The block currently worked on; asserts that one has been set."""
        block = self._work_block
        assert block is not None
        return block

    @work_block.setter
    def work_block(self, block):
        assert block is not None
        self._work_block = block

    @property
    def cur_src_op(self):
        """The operator passed to the last ``prepare_context``; asserts it is set."""
        src_op = self._cur_src_op
        assert src_op is not None
        return src_op

    def in_backward_phase(self):
        """Return True once ``prepare_context`` has seen a loss-grad op."""
        return self._exceed_backward_init_op

    def prepare_context(self, src_op):
        """Make ``src_op`` the current op and translate its variable names.

        Args:
            src_op: the operator to prepare for; its ``desc`` supplies the
                input/output slot names and the variable names in each slot.

        Returns:
            tuple: ``(kinputs, koutputs)``, two dicts mapping each input /
            output slot name to the list of names obtained by looking the
            slot's variable names up in ``self.varname_mapping``.

        Raises:
            AssertionError: if a variable name is missing from
                ``self.varname_mapping``.
        """
        self._cur_src_op = src_op

        if is_loss_grad_op(src_op):
            self._exceed_backward_init_op = True

        def _mapped(varnames):
            renamed = []
            for varname in varnames:
                assert varname in self.varname_mapping
                renamed.append(self.varname_mapping[varname])
            return renamed

        desc = src_op.desc
        # build input varname mapping
        kinputs = {
            slot: _mapped(desc.input(slot)) for slot in desc.input_names()
        }
        # build output varname mapping
        koutputs = {
            slot: _mapped(desc.output(slot)) for slot in desc.output_names()
        }
        return kinputs, koutputs


class BlockState:
    """Record which blocks of a program are forward blocks and which are backward blocks."""

    def __init__(self):
        # Total number of blocks parsed so far (forward + backward).
        self.nblock = 0
        # Indices of the blocks registered by parse_forward_blocks.
        self.forward_indices = []
        # Indices of the blocks registered by parse_backward_blocks.
        self.backward_indices = []
        # Backward block index -> index of its forward block (0 maps to 0).
        self.backward_to_forward_index_map = {}

    def parse_forward_blocks(self, program):
        """Register every block of ``program`` as a forward block.

        The program is first rolled back until its current block index is 0.

        Args:
            program: object exposing ``current_block_idx``, ``_rollback()``
                and ``blocks``; each block exposes ``idx`` and
                ``forward_block_idx``.

        Raises:
            AssertionError: if a block's ``idx`` differs from its position,
                if a block's ``forward_block_idx`` is not -1, or if the
                program has no block.
        """
        while program.current_block_idx != 0:
            program._rollback()

        assert program.current_block_idx == 0

        for idx, block in enumerate(program.blocks):
            assert idx == block.idx, "index doesn't match"
            # Report both the expected value and the offending one.
            assert (
                block.forward_block_idx == -1
            ), "forward_block_idx of forward block [{}] should be -1, but got [{}]".format(
                idx, block.forward_block_idx
            )
            self.forward_indices.append(idx)
            self.nblock += 1

        assert self.nblock >= 1

    def parse_backward_blocks(self, program):
        """Register the blocks that follow the forward blocks as backward blocks.

        Blocks whose position is below ``len(self.forward_indices)`` are
        skipped; each remaining block is recorded together with the forward
        block named by its ``forward_block_idx``.

        Args:
            program: object exposing ``blocks``; each block exposes ``idx``
                and ``forward_block_idx``.

        Raises:
            AssertionError: if block 0 was not registered as a forward block,
                if a block's ``idx`` differs from its position, if its
                ``forward_block_idx`` is not a registered forward block, or
                if the final block count differs from ``len(program.blocks)``.
        """
        assert 0 in self.forward_indices, "forward block idx are {}".format(
            self.forward_indices
        )
        self.backward_to_forward_index_map[0] = 0

        for idx, block in enumerate(program.blocks):
            if idx < len(self.forward_indices):
                continue

            assert idx == block.idx, "index doesn't match"
            assert block.forward_block_idx in self.forward_indices
            self.backward_indices.append(idx)
            self.backward_to_forward_index_map[idx] = block.forward_block_idx
            self.nblock += 1

        assert self.nblock == len(program.blocks)