#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License

import copy
from collections import defaultdict
from paddle.fluid import framework
from paddle.fluid.framework import set_flags
from paddle.fluid import core
from paddle.distributed.passes import PassContext
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
from .utils import is_loss_grad_op

# There always exists a default context for the user, and the user can set it to another one.
_g_default_distributed_context = None


def get_default_distributed_context():
    global _g_default_distributed_context
    if _g_default_distributed_context is None:
        dist_context = DistributedContext()
        set_default_distributed_context(dist_context)
    return _g_default_distributed_context


def set_default_distributed_context(dist_context):
    global _g_default_distributed_context
    _g_default_distributed_context = dist_context

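# Illustrative sketch (hedged) of the default-context pattern described above;
# a minimal outline rather than setup code, assuming it runs before any
# auto-parallel pass touches the context:
#
#     ctx = get_default_distributed_context()          # lazily creates one if absent
#     set_default_distributed_context(DistributedContext())  # replace it with a fresh context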

def _node_id(node):
    return (node.node.graph_id(), node.node.id())


class DistributedContext:
    """
    DistributedContext is used to collect related distributed information for program and graph.
    One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
    """

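    # Illustrative usage sketch for DistributedContext (hedged): `main_prog`,
    # `startup_prog`, `opt` and `some_var` are hypothetical placeholders built
    # elsewhere, not names defined in this module.
    #
    #     dist_context = DistributedContext(serial_main_prog=main_prog,
    #                                       serial_startup_prog=startup_prog,
    #                                       serial_optimizer=opt)
    #     dist_context.initialize(with_graph=True)
    #     dist_attr = dist_context.get_tensor_dist_attr_for_program(some_var)
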
    def __init__(self,
                 serial_main_prog=None,
                 serial_startup_prog=None,
                 serial_optimizer=None,
                 serial_loss=None,
                 feed_vars={},
                 fetch_vars={},
                 cluster=None,
                 strategy=None):
        # Data members related to original programs (unchanged)
        self._original_serial_main_program = serial_main_prog
        self._original_serial_startup_program = serial_startup_prog
        self._original_serial_optimizer = serial_optimizer
        self._original_serial_loss = serial_loss
        self._original_serial_feed_vars = feed_vars
        self._original_serial_fetch_vars = fetch_vars

        # Data members related to programs (changed)
        self._serial_main_program = None
        self._serial_startup_program = None
        self._serial_loss = None
        self._serial_optimizer = None
        self._serial_feed_vars = {}
        self._serial_fetch_vars = {}

        # Data members related to the program
        self._dist_tensors_for_program = {}
        self._dist_ops_for_program = {}

        # Data members related to the graph
        self._serial_graph = None
        self._dist_tensors_for_graph = {}
        self._dist_ops_for_graph = {}
        self._node_id_to_tensor_id = {}
        self._node_id_to_op_id = {}

        # Data members related to the distributed programs
        # Distributed programs
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._dist_op_context = DistributedOperatorContext()
        self._process_meshes = []

        self._cluster = cluster
        self._strategy = strategy

        # Pass Context
        self._pass_context = PassContext()
        self._block_state = BlockState()

        # Other data members
        self._serial_ordered_tensor_nodes = []
        self._serial_ordered_op_nodes = []
        self._serial_ordered_nodes = []
        # self._tensor_id_to_tensor_node_ids = {}

        self._is_initialized = False
        # TODO: need a better way to remove the following flag
        self._need_copy_dist_attr_to_graph = False
        self._backup_pass_context_stack = []
        self._backup_block_state_stack = []
        self._backup_dist_tensors_for_program_stack = []
        self._backup_dist_ops_for_program_stack = []
        self._backup_serial_main_program_stack = []
        self._backup_serial_startup_program_stack = []

        # Flag indicating whether to scale the gradient by the data-parallel size
        self._gradient_scale = True

        # A flag indicating whether the parallelism in use is data parallel
        self._data_parallel = False

    @property
    def serial_main_program(self):
        return self._serial_main_program

    @property
    def serial_startup_program(self):
        return self._serial_startup_program

    @property
    def serial_loss(self):
        return self._serial_loss

    @property
    def serial_optimizer(self):
        return self._serial_optimizer

    @property
    def serial_feed_vars(self):
        return self._serial_feed_vars

    @property
    def serial_fetch_vars(self):
        return self._serial_fetch_vars

    @property
    def dist_main_programs(self):
        return self._dist_main_programs

    @property
    def dist_startup_programs(self):
        return self._dist_startup_programs

    @property
    def cluster(self):
        return self._cluster

    @property
    def strategy(self):
        return self._strategy

    @property
    def serial_graph(self):
        return self._serial_graph

    @property
    def serial_ordered_nodes(self):
        return self._serial_ordered_nodes

    @property
    def process_meshes(self):
        return self._process_meshes

    @property
    def pass_context(self):
        return self._pass_context

    @property
    def dist_op_context(self):
        return self._dist_op_context

    @property
    def block_state(self):
        return self._block_state

    @property
    def has_annotation(self):
        return len(self._dist_tensors_for_program) or len(
            self._dist_ops_for_program)

    @property
    def gradient_scale(self):
        return self._gradient_scale

    @gradient_scale.setter
    def gradient_scale(self, gs):
        self._gradient_scale = gs

    @property
    def data_parallel(self):
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, dp):
        self._data_parallel = dp

    def _backup_serial_info(self, mode):
        self._backup_serial_main_program_stack.append(
            self._serial_main_program.clone())
        self._backup_serial_startup_program_stack.append(
            self._serial_startup_program.clone())
        self._backup_pass_context_stack.append(copy.deepcopy(
            self._pass_context))
        self._backup_block_state_stack.append(copy.deepcopy(self._block_state))

    def _backup_dist_info(self, mode):
        self._backup_dist_tensors_for_program_stack.append(
            copy.deepcopy(self._dist_tensors_for_program))
        self._backup_dist_ops_for_program_stack.append(
            copy.deepcopy(self._dist_ops_for_program))

    def _backup(self, serial=True, serial_mode=None, dist=True, dist_mode=None):
        # Use this function carefully
        if serial:
            self._backup_serial_info(serial_mode)
        if dist:
            self._backup_dist_info(dist_mode)

    def _restore_serial_loss(self):
        if self._original_serial_loss:
            if isinstance(self._original_serial_loss, list):
                if len(self._original_serial_loss) == 1:
                    loss = self._original_serial_loss[0]
                    block_idx = loss.block.idx
                    var_name = loss.name
                    var = self._serial_main_program.blocks[
                        block_idx]._var_recursive(var_name)
                    self._serial_loss = var
                elif len(self._original_serial_loss) == 0:
                    self._serial_loss = []
                else:
                    raise ValueError("multiple loss variables are not supported.")
            else:
                block_idx = self._original_serial_loss.block.idx
                var_name = self._original_serial_loss.name
                var = self._serial_main_program.blocks[
                    block_idx]._var_recursive(var_name)
                self._serial_loss = var

    def _restore_serial_feed_vars(self):
        for key, var_list in self._original_serial_feed_vars.items():
            new_var_list = []
            for var in var_list:
                block_idx = var.block.idx
                var_name = var.name
                var = self._serial_main_program.blocks[
                    block_idx]._var_recursive(var_name)
                new_var_list.append(var)
            self._serial_feed_vars[key] = new_var_list

    def _restore_serial_fetch_vars(self):
        for key, var_list in self._original_serial_fetch_vars.items():
            new_var_list = []
            # metrics is a list of list
            if key == "metrics":
                for inner_var_list in var_list:
                    new_inner_var_list = []
                    for var in inner_var_list:
                        block_idx = var.block.idx
                        var_name = var.name
                        var = self._serial_main_program.blocks[
                            block_idx]._var_recursive(var_name)
                        new_inner_var_list.append(var)
                    new_var_list.append(new_inner_var_list)
            else:
                for var in var_list:
                    block_idx = var.block.idx
                    var_name = var.name
                    var = self._serial_main_program.blocks[
                        block_idx]._var_recursive(var_name)
                    new_var_list.append(var)
            self._serial_fetch_vars[key] = new_var_list

    def _restore_serial_info(self, mode="to_backup"):
        if mode == "to_backup":
            self._serial_main_program = self._backup_serial_main_program_stack.pop(
            )
            self._serial_startup_program = self._backup_serial_startup_program_stack.pop(
            )
        elif mode == "to_original":
            assert self._original_serial_main_program is not None
            assert self._original_serial_startup_program is not None
            self._serial_main_program = self._original_serial_main_program.clone(
            )
            self._serial_startup_program = self._original_serial_startup_program.clone(
            )

        self._restore_serial_loss()
        self._restore_serial_feed_vars()
        self._restore_serial_fetch_vars()
        self._serial_optimizer = self._original_serial_optimizer
        self._pass_context = self._backup_pass_context_stack.pop()
        self._block_state = self._backup_block_state_stack.pop()

    def _restore_dist_info(self, mode="to_backup"):
        if mode == "to_backup":
            self._dist_tensors_for_program = self._backup_dist_tensors_for_program_stack.pop(
            )
            self._dist_ops_for_program = self._backup_dist_ops_for_program_stack.pop(
            )
        elif mode == "to_original":
            assert self._original_dist_tensors_for_program
            assert self._original_dist_ops_for_program
            self._dist_tensors_for_program = copy.deepcopy(
                self._original_dist_tensors_for_program)
            self._dist_ops_for_program = copy.deepcopy(
                self._original_dist_ops_for_program)
        elif mode == "to_default":
            new_tensors_ids = []
            for tensor_id, dist_tensor in self._dist_tensors_for_program.items(
            ):
                if tensor_id in self._tensors_ids:
                    dist_tensor.dist_attr.reset()
                else:
                    new_tensors_ids.append(tensor_id)
            for tensor_id in new_tensors_ids:
                self._dist_tensors_for_program.pop(tensor_id)
            new_ops_ids = []
            for op_id, dist_op in self._dist_ops_for_program.items():
                if op_id in self._ops_ids:
                    dist_op.dist_attr.reset()
                else:
                    new_ops_ids.append(op_id)
            for op_id in new_ops_ids:
                self._dist_ops_for_program.pop(op_id)
        else:
            new_tensors_ids = []
            for tensor_id, dist_tensor in self._dist_tensors_for_program.items(
            ):
                new_tensors_ids.append(tensor_id)
            for tensor_id in new_tensors_ids:
                self._dist_tensors_for_program.pop(tensor_id)
            new_ops_ids = []
            for op_id, dist_op in self._dist_ops_for_program.items():
                new_ops_ids.append(op_id)
            for op_id in new_ops_ids:
                self._dist_ops_for_program.pop(op_id)
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._dist_op_context = DistributedOperatorContext()
        self._need_copy_dist_attr_to_graph = True
        self._process_meshes = []

    def _restore(self,
                 serial=True,
                 serial_mode="to_backup",
                 dist=True,
                 dist_mode="to_backup"):
        # Use this function carefully
        if serial:
            self._restore_serial_info(serial_mode)
        if dist:
            self._restore_dist_info(dist_mode)

    def initialize(self, with_graph=True):
        if not self._is_initialized:
            if not self._serial_main_program:
                if self._original_serial_main_program:
                    self._serial_main_program = self._original_serial_main_program.clone(
                    )
            if not self._serial_startup_program:
                if self._original_serial_startup_program:
                    self._serial_startup_program = self._original_serial_startup_program.clone(
                    )
            if not self._serial_loss:
                self._restore_serial_loss()
            if not self._serial_optimizer:
                self._serial_optimizer = self._original_serial_optimizer
            if not self._serial_feed_vars:
                self._restore_serial_feed_vars()
            if not self._serial_fetch_vars:
                self._restore_serial_fetch_vars()

            self._init_dist_attr_for_program()
            # Backup the original distributed information for later restore
            self._original_dist_tensors_for_program = copy.deepcopy(
                self._dist_tensors_for_program)
            self._original_dist_ops_for_program = copy.deepcopy(
                self._dist_ops_for_program)
            self._tensors_ids = list(self._dist_tensors_for_program.keys())
            self._ops_ids = list(self._dist_ops_for_program.keys())
            self._is_initialized = True

            if with_graph:
                set_flags({"FLAGS_convert_all_blocks": True})
                self._serial_graph = framework.IrGraph(
                    core.Graph(self._serial_main_program.desc))
                self._init_dist_attr_for_graph()
                self._need_copy_dist_attr_to_graph = False

        if self._need_copy_dist_attr_to_graph and with_graph:
            self.copy_dist_attr_from_program_to_graph()

    def add_process_mesh(self, process_mesh):
        assert isinstance(process_mesh, ProcessMesh), \
            'The type of process_mesh must be ProcessMesh.'
        if process_mesh not in self.process_meshes:
            self._process_meshes.append(process_mesh)

    def add_dist_tensor_for_program(self, dist_tensor):
        inner_serial_tensor = dist_tensor.serial_tensor
        inner_serial_tensor_id = inner_serial_tensor.desc.original_id()
        self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor

    def add_dist_op_for_program(self, dist_op):
        inner_serial_op = dist_op.serial_op
        inner_serial_op_id = inner_serial_op.desc.original_id()
        self._dist_ops_for_program[inner_serial_op_id] = dist_op

    def get_dist_tensor_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor
        else:
            serial_tensor_id = serial_tensor.desc.original_id()
            dist_tensor = self._dist_tensors_for_program.get(
                serial_tensor_id, None)
            if dist_tensor:
                return dist_tensor
            else:
                return None

    def get_dist_tensor_for_graph(self, serial_tensor_node):
        serial_tensor_node_id = _node_id(serial_tensor_node)
        return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)

    def get_dist_op_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op
        else:
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op
            else:
                return None

    def del_dist_op_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        if self._dist_ops_for_program.get(serial_tensor_id, None):
            del self._dist_ops_for_program[serial_tensor_id]

    def get_dist_op_for_graph(self, serial_op_node):
        serial_op_node_id = _node_id(serial_op_node)
        return self._dist_ops_for_graph.get(serial_op_node_id, None)

    def get_tensor_dist_attr_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            serial_tensor_id = serial_tensor.desc.original_id()
            dist_tensor = self._dist_tensors_for_program.get(
                serial_tensor_id, None)
            if dist_tensor:
                return dist_tensor.dist_attr
            else:
                return None

    def get_tensor_dist_attr_for_program_with_id(self, tensor_id):
        dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            return None

    def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
        dist_tensor = DistributedTensor(serial_tensor, dist_attr)
        self.add_dist_tensor_for_program(dist_tensor)

    def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
        serial_tensor_node_id = _node_id(serial_tensor_node)
        dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id,
                                                       None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            return None

    def get_op_dist_attr_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op.dist_attr
            else:
                return None

    def get_op_dist_attr_for_program_with_id(self, op_id):
        dist_op = self._dist_ops_for_program.get(op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            return None

    def set_op_dist_attr_for_program(self, serial_op, dist_attr):
        dist_op = DistributedOperator(serial_op, dist_attr)
        self.add_dist_op_for_program(dist_op)

    def get_op_dist_attr_for_graph(self, serial_op_node):
        serial_op_node_id = _node_id(serial_op_node)
        dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            return None

    def get_dist_attr_for_graph(self, serial_node):
        if serial_node.is_var() and serial_node.var() is not None:
            serial_tensor_node_id = _node_id(serial_node)
            dist_tensor = self._dist_tensors_for_graph.get(
                serial_tensor_node_id, None)
            if dist_tensor:
                return dist_tensor.dist_attr
            else:
                return None
        if serial_node.is_op() and serial_node.op() is not None:
            serial_op_node_id = _node_id(serial_node)
            dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
            if dist_op:
                return dist_op.dist_attr
            else:
                return None
        return None

    def _init_dist_attr_for_program(self, no_default=False):
        # Copy the dist tensors and dist ops annotated by users from the default context
        if not no_default:
            default_ctx = get_default_distributed_context()
            self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
        else:
            default_ctx = self
        # Copy the data parallel flag from the default context
        self._data_parallel = default_ctx.data_parallel
        for block in self._serial_main_program.blocks:
            for tensor in block.vars.values():
                # Copy the distributed tensors in the default context
                default_dist_tensor = default_ctx.get_dist_tensor_for_program(
                    tensor)
                if default_dist_tensor and default_ctx is not self:
                    self.add_dist_tensor_for_program(default_dist_tensor)
                current_dist_tensor = self.get_dist_tensor_for_program(tensor)
                if current_dist_tensor is None:
                    dist_tensor = DistributedTensor(tensor)
                    self.add_dist_tensor_for_program(dist_tensor)
            for op in block.ops:
                # Copy the distributed operators in the default context
                default_dist_op = default_ctx.get_dist_op_for_program(op)
                if default_dist_op and default_ctx is not self:
                    self.add_dist_op_for_program(default_dist_op)
                current_dist_op = self.get_dist_op_for_program(op)
                if current_dist_op is None:
                    dist_op = DistributedOperator(op)
                    self.add_dist_op_for_program(dist_op)
        self._original_dist_tensors_for_program = copy.deepcopy(
            self._dist_tensors_for_program)
        self._original_dist_ops_for_program = copy.deepcopy(
            self._dist_ops_for_program)

    def _order_nodes_by_program_order(self):

        def _contains(nodes, target_node):
            for node in nodes:
                if _node_id(node) == _node_id(target_node):
                    return True
            return False

        serial_ordered_tensor_nodes = []
        serial_ordered_op_nodes = []
        all_nodes = []
        for idx, graph in enumerate(self._serial_graph.all_sub_graphs()):
            for node in graph.all_nodes():
                all_nodes.append(node)
        for node in all_nodes:
            if node.is_var() and node.var() is not None:
                serial_ordered_tensor_nodes.append(node)
            if node.is_op() and node.op() is not None:
                serial_ordered_op_nodes.append(node)
        serial_ordered_tensor_nodes.sort(
            key=lambda node: node.node.original_desc_id())
        serial_ordered_op_nodes.sort(
            key=lambda node: node.node.original_desc_id())
        num_nodes_before = len(serial_ordered_tensor_nodes) + len(
            serial_ordered_op_nodes)

        new_serial_ordered_tensor_nodes = []
        new_serial_ordered_op_nodes = []
        new_serial_ordered_nodes = []
        for op_node in serial_ordered_op_nodes:
            tensor_nodes = []
            for tensor_node in op_node.inputs:
                if tensor_node.is_var() \
                    and tensor_node.var() is not None \
                    and not _contains(new_serial_ordered_nodes, tensor_node):
                    tensor_nodes.append(tensor_node)
                    new_serial_ordered_tensor_nodes.append(tensor_node)
            tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
            new_serial_ordered_nodes.extend(tensor_nodes)
            new_serial_ordered_nodes.append(op_node)
            new_serial_ordered_op_nodes.append(op_node)
            tensor_nodes = []
            for tensor_node in op_node.outputs:
                if tensor_node.is_var() \
                    and tensor_node.var() is not None \
                    and not _contains(new_serial_ordered_nodes, tensor_node):
                    tensor_nodes.append(tensor_node)
                    new_serial_ordered_tensor_nodes.append(tensor_node)
            tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
            new_serial_ordered_nodes.extend(tensor_nodes)
        new_serial_ordered_tensor_nodes.sort(
            key=lambda node: node.node.original_desc_id())
        new_serial_ordered_op_nodes.sort(
            key=lambda node: node.node.original_desc_id())
        self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
        self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
        self._serial_ordered_nodes = new_serial_ordered_nodes
        assert len(self._serial_ordered_nodes) == len(
            self._serial_ordered_tensor_nodes) + len(
                self._serial_ordered_op_nodes)
        self._serial_orphan_tensor_nodes = []
        for tensor_node in serial_ordered_tensor_nodes:
            if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
                self._serial_orphan_tensor_nodes.append(tensor_node)
        if len(self._serial_ordered_nodes) != num_nodes_before:
            print(
                "WARNING: there are some orphan tensors or ops which are not used in the execution."
            )

    def _init_dist_attr_for_graph(self):
        # Convert program to graph and initialize the distributed attributes
        self._order_nodes_by_program_order()
        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                dist_tensor = None
                tensor_id = node.node.original_desc_id()
                for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items(
                ):
                    if tensor_id == cur_tensor_id \
                        or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
                        dist_tensor = cur_dist_tensor
                        self._node_id_to_tensor_id[_node_id(
                            node)] = cur_tensor_id
                assert dist_tensor is not None, \
                    "Tensor must have a distributed tensor after the initialization for program."
                serial_tensor_node_id = _node_id(node)
                new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
                                                    dist_tensor.dist_attr)
                self._dist_tensors_for_graph[
                    serial_tensor_node_id] = new_dist_tensor
            if node.is_op() and node.op() is not None:
                dist_op = None
                op_id = node.node.original_desc_id()
                for cur_op_id, cur_dist_op in self._dist_ops_for_program.items(
                ):
                    if op_id == cur_op_id \
                        or op_id == cur_dist_op.serial_op.desc.original_id():
                        dist_op = cur_dist_op
                        self._node_id_to_op_id[_node_id(node)] = cur_op_id
                assert dist_op is not None, \
                    "Operator must have a distributed operator after the initialization for program."
                serial_op_node_id = _node_id(node)
                new_dist_op = DistributedOperator(dist_op.serial_op,
                                                  dist_op.dist_attr)
                self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

    def clear_dist_info_for_program(self):
        self._dist_tensors_for_program.clear()
        self._dist_ops_for_program.clear()

    def clear_dist_info_for_graph(self):
        self._dist_tensors_for_graph.clear()
        self._dist_ops_for_graph.clear()

    def copy_dist_attr_from_program_to_graph(self):
        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                dist_tensor = None
                tensor_id = node.node.original_desc_id()
                for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items(
                ):
                    if tensor_id == cur_tensor_id \
                        or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
                        dist_tensor = cur_dist_tensor
                assert dist_tensor is not None, \
                    "Tensor must have a distributed tensor after the initialization for program."
                serial_tensor_node_id = _node_id(node)
                new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
                                                    dist_tensor.dist_attr)
                self._dist_tensors_for_graph[
                    serial_tensor_node_id] = new_dist_tensor
            if node.is_op() and node.op() is not None:
                dist_op = None
                op_id = node.node.original_desc_id()
                for cur_op_id, cur_dist_op in self._dist_ops_for_program.items(
                ):
                    if op_id == cur_op_id \
                        or op_id == cur_dist_op.serial_op.desc.original_id():
                        dist_op = cur_dist_op
                assert dist_op is not None, \
                    "Operator must have a distributed operator after the initialization for program."
                serial_op_node_id = _node_id(node)
                new_dist_op = DistributedOperator(dist_op.serial_op,
                                                  dist_op.dist_attr)
                self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

    def copy_dist_attr_from_graph_to_program(self):
        assert self._is_initialized, \
            "Both program and graph must be initialized."
        updated_tensors = {}
        # all_nodes = self._serial_graph.all_nodes()
        all_nodes = self._serial_ordered_nodes
        for node in all_nodes:
            if node.is_var() and node.var() is not None:
                tensor_id = self._node_id_to_tensor_id[_node_id(node)]
                updated = updated_tensors.get(tensor_id, False)
                # If a var has multiple var nodes in the graph, only use the first one for now
                if not updated:
                    tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph(
                        node)
                    dist_tensor_for_program = self._dist_tensors_for_program[
                        tensor_id]
                    dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
                    updated_tensors[tensor_id] = True
            if node.is_op() and node.op() is not None:
                op_id = self._node_id_to_op_id[_node_id(node)]
                op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
                dist_op_for_program = self._dist_ops_for_program[op_id]
                dist_op_for_program.dist_attr = op_dist_attr_for_graph
        # TODO: the completion algorithm will skip orphan tensors;
        # here we just set their process_mesh to the first one.
        for orphan_node in self._serial_orphan_tensor_nodes:
            serial_tensor_id = orphan_node.var().id()
            dist_tensor = self._dist_tensors_for_program.get(
                serial_tensor_id, None)
            if dist_tensor:
                dist_tensor.dist_attr.process_mesh = self._process_meshes[0]
            else:
                serial_tensor_id = orphan_node.var().original_id()
                dist_tensor = self._dist_tensors_for_program.get(
                    serial_tensor_id, None)
                dist_tensor.dist_attr.process_mesh = self._process_meshes[0]

    def amend_dist_attr_for_program(self):
        for dist_tensor in self._dist_tensors_for_program.values():
            serial_tensor = dist_tensor.serial_tensor
            dist_attr = dist_tensor.dist_attr
            if serial_tensor.type == core.VarDesc.VarType.READER \
                or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
                or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
                tensor_shape = []
            else:
                tensor_shape = serial_tensor.shape
            dims_mapping = dist_attr.dims_mapping
            process_mesh_shape = dist_attr.process_mesh.topology
            process_mesh_processes = dist_attr.process_mesh.processes
            # If the dimension of tensor is less than the sharding dimension of process mesh,
            # we just amend the dimension mapping to -1. (Is this really OK?)
            for i in range(len(tensor_shape)):
                if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
                    and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
                    dims_mapping[i] = -1
                if dims_mapping[i] != -1 and len(process_mesh_processes) == 1:
                    dims_mapping[i] = -1

        for dist_op in self._dist_ops_for_program.values():
            serial_op = dist_op.serial_op
            dist_attr = dist_op.dist_attr
            process_mesh_shape = dist_attr.process_mesh.topology
            process_mesh_processes = dist_attr.process_mesh.processes
            for arg_name in serial_op.input_arg_names:
                if dist_op.get_serial_input(arg_name) is None:
                    tensor_shape = []
                else:
                    if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \
                        or dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
                        or dist_op.serial_op.type == "create_py_reader":
                        tensor_shape = []
                    else:
                        tensor_shape = dist_op.get_serial_input(arg_name).shape
                dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                # If the dimension of tensor is less than the sharding dimension of process mesh,
                # we just amend the dimension mapping to -1. (Is this really OK?)
                for i in range(len(tensor_shape)):
                    if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
                        and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
                        dims_mapping[i] = -1
                    if dims_mapping[i] != -1 and len(
                            process_mesh_processes) == 1:
                        dims_mapping[i] = -1
            for arg_name in serial_op.output_arg_names:
                if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \
                    or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
                    or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.STEP_SCOPES:
                    tensor_shape = []
                else:
                    tensor_shape = dist_op.get_serial_output(arg_name).shape
                dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                # If the dimension of tensor is less than the sharding dimension of process mesh,
                # we just amend the dimension mapping to -1. (Is this really OK?)
                for i in range(len(tensor_shape)):
                    if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
                        and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
                        dims_mapping[i] = -1
                    if dims_mapping[i] != -1 and len(
                            process_mesh_processes) == 1:
                        dims_mapping[i] = -1
            if len(process_mesh_processes) == 1:
                dist_op.dist_attr.impl_type = "default"
                dist_op.dist_attr.impl_idx = 0

    def validate_dist_attr_for_program(self):
        if not self._is_initialized:
            assert False, \
                "Program must be initialized before validating its distributed attributes"
        for block in self.serial_main_program.blocks:
            for tensor in block.vars.values():
                dist_tensor = self.get_dist_tensor_for_program(tensor)
                assert dist_tensor is not None, \
                    "Tensor {} does not have a distributed attribute.".format(
                        tensor.name)
                if (dist_tensor
                        is not None) and (not dist_tensor.validate_dist_attr()):
                    assert False, "Tensor {} (id: {}, original_id: {}) has wrong distributed attributes {}.".format(
                        dist_tensor.serial_tensor.name,
                        dist_tensor.serial_tensor.desc.id(),
                        dist_tensor.serial_tensor.desc.original_id(),
                        dist_tensor.dist_attr)
            for op in block.ops:
                dist_op = self.get_dist_op_for_program(op)
                assert dist_op is not None, \
                    "Operator {} does not have a distributed attribute.".format(
                        op.type)
                if (dist_op is not None) and (not dist_op.validate_dist_attr()):
                    assert False, "Operator {} (id: {}, original_id: {}) has wrong distributed attributes {}.".format(
                        dist_op.serial_op.type, dist_op.serial_op.desc.id(),
                        dist_op.serial_op.desc.original_id(), dist_op.dist_attr)
        return True

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in [
                "_original_serial_main_program", "_original_serial_startup_program", \
                "_serial_main_program", "_serial_startup_program", "_serial_graph", \
                "_dist_main_programs", "_dist_startup_programs", \
                "_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
                "_serial_ordered_op_nodes", "_original_serial_loss", \
                "_original_serial_feed_vars", "_original_serial_fetch_vars", \
                "_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_serial_optimizer", \
                "_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
                "_pass_context"]:
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))

        # update dist tensor's dist_context
        for key in result._dist_tensors_for_program.keys():
            result._dist_tensors_for_program[key]._dist_context = result
        return result


class DistributedOperatorContext:
    """
    DistributedOperatorContext is used to create a dist op desc in Program.
    Every time a new dist op is created, the context should be updated for it accordingly.
    """

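    # Illustrative usage sketch for DistributedOperatorContext (hedged):
    # `dist_context`, `dist_main_prog`, `dist_startup_prog`, `mapping`, `rank`
    # and `src_op` are hypothetical placeholders, not names from this module.
    #
    #     dist_op_context = dist_context.dist_op_context
    #     dist_op_context.dst_main_program = dist_main_prog
    #     dist_op_context.dst_startup_program = dist_startup_prog
    #     dist_op_context.varname_mapping = mapping
    #     dist_op_context.rank_id = rank
    #     kinputs, koutputs = dist_op_context.prepare_context(src_op)
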
    def __init__(self):
        self._dst_main_program = None
        self._main_block = None
        self._dst_startup_program = None
        self._startup_block = None
        self._cur_src_op = None
        self._cur_dist_attr = None
        self.grad_op_id_to_op_id = {}
        self.grad_var_to_var = defaultdict(dict)
        self._work_block = None
        self.already_init_sync_vars = set()
        self.varname_mapping = None
        self.rank_id = None
        # NOTE: Support correct parallelism for high-order differential models.
        # By default, _exceed_backward_init_op is False, which means we are in the forward phase;
        # once it becomes True, we are in the backward phase.
        # The final solution should be to revise the high-order differential logic for these two phases in the future.
        self._exceed_backward_init_op = False

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in [
                    "_dst_main_program", "_dst_startup_program", "_cur_src_op",
                    "_work_block", "_main_block", "_startup_block"
            ]:
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

    @property
    def dst_main_program(self):
        return self._dst_main_program

    @dst_main_program.setter
    def dst_main_program(self, prog):
        self._dst_main_program = prog
        self._main_block = prog.blocks[0]

    @property
    def main_block(self):
        return self._main_block

    @property
    def dst_startup_program(self):
        return self._dst_startup_program

    @dst_startup_program.setter
    def dst_startup_program(self, prog):
        self._dst_startup_program = prog
        self._startup_block = prog.blocks[0]

    @property
    def startup_block(self):
        return self._startup_block

    @property
    def work_block(self):
        assert self._work_block is not None
        return self._work_block

    @work_block.setter
    def work_block(self, block):
        assert block is not None
        self._work_block = block

    @property
    def cur_src_op(self):
        assert self._cur_src_op is not None
        return self._cur_src_op

    def in_backward_phase(self):
        return self._exceed_backward_init_op

    def prepare_context(self, src_op):

        self._cur_src_op = src_op

        if is_loss_grad_op(src_op):
            self._exceed_backward_init_op = True

        # build input varname mapping
        kinputs = {}
        for input_name in src_op.desc.input_names():
            varnames = []
            for varname in src_op.desc.input(input_name):
                assert varname in self.varname_mapping
                varnames.append(self.varname_mapping[varname])
            kinputs[input_name] = varnames

        # build output varname mapping
        koutputs = {}
        for output_name in src_op.desc.output_names():
            varnames = []
            for varname in src_op.desc.output(output_name):
                assert varname in self.varname_mapping
                varnames.append(self.varname_mapping[varname])
            koutputs[output_name] = varnames

        return kinputs, koutputs


class BlockState(object):
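    """
    BlockState records the block structure of a serial program:
    parse_forward_blocks collects the indices of the forward blocks, and
    parse_backward_blocks maps each backward block to the forward block it
    was derived from via backward_to_forward_index_map.
    """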

    def __init__(self):
        self.nblock = 0
        self.forward_indices = []
        self.backward_indices = []
        self.backward_to_forward_index_map = {}

    def parse_forward_blocks(self, program):

        while program.current_block_idx != 0:
            program._rollback()

        assert program.current_block_idx == 0

        for idx, block in enumerate(program.blocks):

            assert idx == block.idx, "index doesn't match"
            assert block.forward_block_idx == -1, "forward_block_idx of forward block [{}] should be -1, but got [{}]".format(
                idx, block.forward_block_idx)
            self.forward_indices.append(idx)
            self.nblock += 1

        assert self.nblock >= 1

    def parse_backward_blocks(self, program):

        assert 0 in self.forward_indices, "forward block indices are {}".format(
            self.forward_indices)
        self.backward_to_forward_index_map[0] = 0

        for idx, block in enumerate(program.blocks):

            if idx < len(self.forward_indices):
                continue

            assert idx == block.idx, "index doesn't match"
            assert block.forward_block_idx in self.forward_indices
            self.backward_indices.append(idx)
            self.backward_to_forward_index_map[idx] = block.forward_block_idx
            self.nblock += 1

        assert self.nblock == len(program.blocks)