#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from collections import defaultdict

import paddle.fluid
from paddle.fluid import framework
from paddle.fluid.framework import get_flags, set_flags
from paddle.fluid import core
from paddle.distributed.passes import PassContext

from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .process_mesh import ProcessMesh
from .utils import is_loss_grad_op, is_loss_op

# There always exists a default context for the user, and the user can set it to another one.
_g_default_distributed_context = None


def get_default_distributed_context():
    global _g_default_distributed_context
    if _g_default_distributed_context is None:
        dist_context = DistributedContext()
        set_default_distributed_context(dist_context)
    return _g_default_distributed_context


def set_default_distributed_context(dist_context):
    global _g_default_distributed_context
    _g_default_distributed_context = dist_context
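
# Illustrative usage of the default-context helpers above:
#
#     ctx = get_default_distributed_context()      # lazily creates a context
#     assert ctx is get_default_distributed_context()
#     set_default_distributed_context(DistributedContext())  # install another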


def _node_id(node):
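    # Node ids are only unique within a single graph, so pair the graph id
    # with the node id to get a key that is unique across all sub-graphs.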
    return (node.node.graph_id(), node.node.id())


class DistributedContext:
    """
    DistributedContext is used to collect related distributed information for program and graph.
    One auto-parallel run should use its own DistributedContext to avoid interfering with other runs.
    """

    def __init__(self,
                 serial_main_prog=None,
                 serial_startup_prog=None,
                 serial_optimizer=None,
                 serial_loss=None,
                 feed_vars=None,
                 fetch_vars=None,
                 cluster=None,
                 strategy=None):
        # Data members related to original programs (unchanged)
        self._original_serial_main_program = serial_main_prog
        self._original_serial_startup_program = serial_startup_prog
        self._original_serial_optimizer = serial_optimizer
        self._original_serial_loss = serial_loss
        # Use None defaults to avoid the shared-mutable-default-argument pitfall.
        self._original_serial_feed_vars = feed_vars if feed_vars is not None else {}
        self._original_serial_fetch_vars = fetch_vars if fetch_vars is not None else {}

        # Data members related to programs (changed)
        self._serial_main_program = None
        self._serial_startup_program = None
        self._serial_loss = None
        self._serial_optimizer = None
        self._serial_feed_vars = {}
        self._serial_fetch_vars = {}
        self._lr_optimizer = None  # record the optimizer holding lr_scheduler

        # Data members related to the program
        self._dist_tensors_for_program = {}
        self._dist_ops_for_program = {}

        # Data members related to the graph
        self._serial_graph = None
        self._dist_tensors_for_graph = {}
        self._dist_ops_for_graph = {}
        self._node_id_to_tensor_id = {}
        self._node_id_to_op_id = {}

        # Data members related to the distributed programs
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._dist_op_context = DistributedOperatorContext()
        self._process_meshes = []

        self._cluster = cluster
        self._strategy = strategy

        # Pass Context
        self._pass_context = PassContext()
        self._block_state = BlockState()

        # Other data members
        self._serial_ordered_tensor_nodes = []
        self._serial_ordered_op_nodes = []
        self._serial_ordered_nodes = []
        # self._tensor_id_to_tensor_node_ids = {}

        self._is_initialized = False
        # TODO: need a better way to remove the following flag
        self._need_copy_dist_attr_to_graph = False
        self._backup_pass_context_stack = []
        self._backup_block_state_stack = []
        self._backup_dist_tensors_for_program_stack = []
        self._backup_dist_ops_for_program_stack = []
        self._backup_serial_main_program_stack = []
        self._backup_serial_startup_program_stack = []

        # Flag indicating whether to scale the gradient by the data-parallel size
        self._gradient_scale = True

        # A flag indicating whether the parallelism in use is data parallel
        self._data_parallel = False

    @property
    def serial_main_program(self):
        return self._serial_main_program

    @property
    def serial_startup_program(self):
        return self._serial_startup_program

    @property
    def serial_loss(self):
        return self._serial_loss

    @property
    def serial_optimizer(self):
        return self._serial_optimizer

    @property
    def serial_feed_vars(self):
        return self._serial_feed_vars

    @property
    def serial_fetch_vars(self):
        return self._serial_fetch_vars

    @property
    def dist_main_programs(self):
        return self._dist_main_programs

    @property
    def dist_startup_programs(self):
        return self._dist_startup_programs

    @property
    def cluster(self):
        return self._cluster

    @property
    def strategy(self):
        return self._strategy

    @property
    def serial_graph(self):
        return self._serial_graph

    @property
    def serial_ordered_nodes(self):
        return self._serial_ordered_nodes

    @property
    def process_meshes(self):
        return self._process_meshes

    @property
    def pass_context(self):
        return self._pass_context

    @property
    def dist_op_context(self):
        return self._dist_op_context

    @property
    def block_state(self):
        return self._block_state

    @property
    def has_annotation(self):
        return len(self._dist_tensors_for_program) > 0 or len(
            self._dist_ops_for_program) > 0

    @property
    def gradient_scale(self):
        return self._gradient_scale

    @gradient_scale.setter
    def gradient_scale(self, gs):
        self._gradient_scale = gs

    @property
    def data_parallel(self):
        return self._data_parallel

    @data_parallel.setter
    def data_parallel(self, dp):
        self._data_parallel = dp

    def _backup_serial_info(self, mode):
        self._backup_serial_main_program_stack.append(
            self._serial_main_program.clone())
        self._backup_serial_startup_program_stack.append(
            self._serial_startup_program.clone())
        self._backup_pass_context_stack.append(copy.deepcopy(
            self._pass_context))
        self._backup_block_state_stack.append(copy.deepcopy(self._block_state))

    def _backup_dist_info(self, mode):
        self._backup_dist_tensors_for_program_stack.append(
            copy.deepcopy(self._dist_tensors_for_program))
        self._backup_dist_ops_for_program_stack.append(
            copy.deepcopy(self._dist_ops_for_program))

    def _backup(self, serial=True, serial_mode=None, dist=True, dist_mode=None):
        # Use this function carefully
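        # A snapshot taken here pairs with _restore(...): e.g. a pass that may
        # mutate the programs can call _backup(serial=True, dist=True) first
        # and later roll back with _restore(serial_mode="to_backup",
        # dist_mode="to_backup").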
        if serial:
            self._backup_serial_info(serial_mode)
        if dist:
            self._backup_dist_info(dist_mode)

    def _restore_serial_loss(self):
        if self._original_serial_loss:
            if isinstance(self._original_serial_loss, list):
                if len(self._original_serial_loss) == 1:
                    loss = self._original_serial_loss[0]
                    block_idx = loss.block.idx
                    var_name = loss.name
                    var = self._serial_main_program.blocks[
                        block_idx]._var_recursive(var_name)
                    self._serial_loss = var
                elif len(self._original_serial_loss) == 0:
                    self._serial_loss = []
                else:
                    raise ValueError("multiple loss vars are not supported.")
            else:
                block_idx = self._original_serial_loss.block.idx
                var_name = self._original_serial_loss.name
                var = self._serial_main_program.blocks[
                    block_idx]._var_recursive(var_name)
                self._serial_loss = var

    def _restore_serial_feed_vars(self):
        for key, var_list in self._original_serial_feed_vars.items():
            new_var_list = []
            for var in var_list:
                block_idx = var.block.idx
                var_name = var.name
                var = self._serial_main_program.blocks[
                    block_idx]._var_recursive(var_name)
                new_var_list.append(var)
            self._serial_feed_vars[key] = new_var_list

    def _restore_serial_fetch_vars(self):
        for key, var_list in self._original_serial_fetch_vars.items():
            new_var_list = []
            # metrics is a list of lists
            if key == "metrics":
                for inner_var_list in var_list:
                    new_inner_var_list = []
                    for var in inner_var_list:
                        block_idx = var.block.idx
                        var_name = var.name
                        var = self._serial_main_program.blocks[
                            block_idx]._var_recursive(var_name)
                        new_inner_var_list.append(var)
                    new_var_list.append(new_inner_var_list)
            else:
                for var in var_list:
                    block_idx = var.block.idx
                    var_name = var.name
                    var = self._serial_main_program.blocks[
                        block_idx]._var_recursive(var_name)
                    new_var_list.append(var)
            self._serial_fetch_vars[key] = new_var_list

    def _restore_serial_info(self, mode="to_backup"):
        if mode == "to_backup":
            self._serial_main_program = self._backup_serial_main_program_stack.pop(
            )
            self._serial_startup_program = self._backup_serial_startup_program_stack.pop(
            )
        elif mode == "to_original":
            assert self._original_serial_main_program is not None
            assert self._original_serial_startup_program is not None
            self._serial_main_program = self._original_serial_main_program.clone(
            )
            self._serial_startup_program = self._original_serial_startup_program.clone(
            )

        self._restore_serial_loss()
        self._restore_serial_feed_vars()
        self._restore_serial_fetch_vars()
        self._serial_optimizer = self._original_serial_optimizer
        self._pass_context = self._backup_pass_context_stack.pop()
        self._block_state = self._backup_block_state_stack.pop()

    def _restore_dist_info(self, mode="to_backup"):
        if mode == "to_backup":
            self._dist_tensors_for_program = self._backup_dist_tensors_for_program_stack.pop(
            )
            self._dist_ops_for_program = self._backup_dist_ops_for_program_stack.pop(
            )
        elif mode == "to_original":
            assert self._original_dist_tensors_for_program
            assert self._original_dist_ops_for_program
            self._dist_tensors_for_program = copy.deepcopy(
                self._original_dist_tensors_for_program)
            self._dist_ops_for_program = copy.deepcopy(
                self._original_dist_ops_for_program)
        elif mode == "to_default":
            new_tensors_ids = []
            for tensor_id, dist_tensor in self._dist_tensors_for_program.items(
            ):
                if tensor_id in self._tensors_ids:
                    dist_tensor.dist_attr.reset()
                else:
                    new_tensors_ids.append(tensor_id)
            for tensor_id in new_tensors_ids:
                self._dist_tensors_for_program.pop(tensor_id)
            new_ops_ids = []
            for op_id, dist_op in self._dist_ops_for_program.items():
                if op_id in self._ops_ids:
                    dist_op.dist_attr.reset()
                else:
                    new_ops_ids.append(op_id)
            for op_id in new_ops_ids:
                self._dist_ops_for_program.pop(op_id)
        else:
            new_tensors_ids = []
            for tensor_id, dist_tensor in self._dist_tensors_for_program.items(
            ):
                new_tensors_ids.append(tensor_id)
            for tensor_id in new_tensors_ids:
                self._dist_tensors_for_program.pop(tensor_id)
            new_ops_ids = []
            for op_id, dist_op in self._dist_ops_for_program.items():
                new_ops_ids.append(op_id)
            for op_id in new_ops_ids:
                self._dist_ops_for_program.pop(op_id)
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._dist_op_context = DistributedOperatorContext()
        self._need_copy_dist_attr_to_graph = True
        self._process_meshes = []

    def _restore(self,
                 serial=True,
                 serial_mode="to_backup",
                 dist=True,
                 dist_mode="to_backup"):
        # Use this function carefully
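        # serial_mode: "to_backup" pops the snapshot pushed by _backup;
        # "to_original" re-clones the user-provided programs.
        # dist_mode also accepts "to_default" (reset known dist attrs, drop
        # new ones); any other value clears all distributed information.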
        if serial:
            self._restore_serial_info(serial_mode)
        if dist:
            self._restore_dist_info(dist_mode)

    def initialize(self, with_graph=True):
        if not self._is_initialized:
            if not self._serial_main_program:
                if self._original_serial_main_program:
                    self._serial_main_program = self._original_serial_main_program.clone(
                    )
            if not self._serial_startup_program:
                if self._original_serial_startup_program:
                    self._serial_startup_program = self._original_serial_startup_program.clone(
                    )
            if not self._serial_loss:
                self._restore_serial_loss()
            if not self._serial_optimizer:
                self._serial_optimizer = self._original_serial_optimizer
            if not self._serial_feed_vars:
                self._restore_serial_feed_vars()
            if not self._serial_fetch_vars:
                self._restore_serial_fetch_vars()

            self._init_dist_attr_for_program()
            # Backup the original distributed information for later restore
            self._original_dist_tensors_for_program = copy.deepcopy(
                self._dist_tensors_for_program)
            self._original_dist_ops_for_program = copy.deepcopy(
                self._dist_ops_for_program)
            self._tensors_ids = list(self._dist_tensors_for_program.keys())
            self._ops_ids = list(self._dist_ops_for_program.keys())
            self._is_initialized = True

            if with_graph:
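                # Convert every block of the program into the IrGraph (not
                # just the global block) so control-flow sub-blocks are
                # covered too.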
                set_flags({"FLAGS_convert_all_blocks": True})
                self._serial_graph = framework.IrGraph(
                    core.Graph(self._serial_main_program.desc))
                self._init_dist_attr_for_graph()
                self._need_copy_dist_attr_to_graph = False

        if self._need_copy_dist_attr_to_graph and with_graph:
            self.copy_dist_attr_from_program_to_graph()

    def add_process_mesh(self, process_mesh):
        assert isinstance(process_mesh, ProcessMesh), \
            'The type of process_mesh must be ProcessMesh.'
        if process_mesh not in self.process_meshes:
            self._process_meshes.append(process_mesh)

    def add_dist_tensor_for_program(self, dist_tensor):
        inner_serial_tensor = dist_tensor.serial_tensor
        inner_serial_tensor_id = inner_serial_tensor.desc.original_id()
        self._dist_tensors_for_program[inner_serial_tensor_id] = dist_tensor

    def add_dist_op_for_program(self, dist_op):
        inner_serial_op = dist_op.serial_op
        inner_serial_op_id = inner_serial_op.desc.original_id()
        self._dist_ops_for_program[inner_serial_op_id] = dist_op

    def get_dist_tensor_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor
        else:
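            # Fall back to the pre-clone id: cloning a program assigns new
            # desc ids, while original_id() still refers to the annotated desc.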
            serial_tensor_id = serial_tensor.desc.original_id()
            dist_tensor = self._dist_tensors_for_program.get(
                serial_tensor_id, None)
            if dist_tensor:
                return dist_tensor
            else:
                return None

    def get_dist_tensor_for_graph(self, serial_tensor_node):
        serial_tensor_node_id = _node_id(serial_tensor_node)
        return self._dist_tensors_for_graph.get(serial_tensor_node_id, None)

    def get_dist_op_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op
        else:
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op
            else:
                return None

    def del_dist_op_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        if self._dist_ops_for_program.get(serial_op_id, None):
            del self._dist_ops_for_program[serial_op_id]

    def get_dist_op_for_graph(self, serial_op_node):
        serial_op_node_id = _node_id(serial_op_node)
        return self._dist_ops_for_graph.get(serial_op_node_id, None)

    def get_tensor_dist_attr_for_program(self, serial_tensor):
        serial_tensor_id = serial_tensor.desc.id()
        dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            serial_tensor_id = serial_tensor.desc.original_id()
            dist_tensor = self._dist_tensors_for_program.get(
                serial_tensor_id, None)
            if dist_tensor:
                return dist_tensor.dist_attr
            else:
                return None

    def get_tensor_dist_attr_for_program_with_id(self, tensor_id):
        dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            return None

    def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
        dist_tensor = DistributedTensor(serial_tensor, dist_attr)
        self.add_dist_tensor_for_program(dist_tensor)

    def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
        serial_tensor_node_id = _node_id(serial_tensor_node)
        dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id,
                                                       None)
        if dist_tensor:
            return dist_tensor.dist_attr
        else:
            return None

    def get_op_dist_attr_for_program(self, serial_op):
        serial_op_id = serial_op.desc.id()
        dist_op = self._dist_ops_for_program.get(serial_op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            serial_op_id = serial_op.desc.original_id()
            dist_op = self._dist_ops_for_program.get(serial_op_id, None)
            if dist_op:
                return dist_op.dist_attr
            else:
                return None

    def get_op_dist_attr_for_program_with_id(self, op_id):
        dist_op = self._dist_ops_for_program.get(op_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            return None

    def set_op_dist_attr_for_program(self, serial_op, dist_attr):
        dist_op = DistributedOperator(serial_op, dist_attr)
        self.add_dist_op_for_program(dist_op)

    def get_op_dist_attr_for_graph(self, serial_op_node):
        serial_op_node_id = _node_id(serial_op_node)
        dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
        if dist_op:
            return dist_op.dist_attr
        else:
            return None

    def get_dist_attr_for_graph(self, serial_node):
        if serial_node.is_var() and serial_node.var() is not None:
            serial_tensor_node_id = _node_id(serial_node)
            dist_tensor = self._dist_tensors_for_graph.get(
                serial_tensor_node_id, None)
            if dist_tensor:
                return dist_tensor.dist_attr
            else:
                return None
        if serial_node.is_op() and serial_node.op() is not None:
            serial_op_node_id = _node_id(serial_node)
            dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
            if dist_op:
                return dist_op.dist_attr
            else:
                return None
        return None

    def _init_dist_attr_for_program(self, no_default=False):
        # Copy the dist tensors and dist ops annotated by users from the default context
        if not no_default:
            default_ctx = get_default_distributed_context()
            self._process_meshes = copy.deepcopy(default_ctx.process_meshes)
        else:
            default_ctx = self
        # Copy the data parallel flag from the default context
        self._data_parallel = default_ctx.data_parallel
        for block in self._serial_main_program.blocks:
            for tensor in block.vars.values():
                # Copy the distributed tensors in the default context
                default_dist_tensor = default_ctx.get_dist_tensor_for_program(
                    tensor)
                if default_dist_tensor and default_ctx is not self:
                    self.add_dist_tensor_for_program(default_dist_tensor)
                current_dist_tensor = self.get_dist_tensor_for_program(tensor)
                if current_dist_tensor is None:
                    dist_tensor = DistributedTensor(tensor)
                    self.add_dist_tensor_for_program(dist_tensor)
            for op in block.ops:
                # Copy the distributed operators in the default context
                default_dist_op = default_ctx.get_dist_op_for_program(op)
                if default_dist_op and default_ctx is not self:
                    self.add_dist_op_for_program(default_dist_op)
                current_dist_op = self.get_dist_op_for_program(op)
                if current_dist_op is None:
                    dist_op = DistributedOperator(op)
                    self.add_dist_op_for_program(dist_op)
        self._original_dist_tensors_for_program = copy.deepcopy(
            self._dist_tensors_for_program)
        self._original_dist_ops_for_program = copy.deepcopy(
            self._dist_ops_for_program)

    def _order_nodes_by_program_order(self):
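        # Rebuild an execution-order node list: visit ops by their original
        # desc id and emit each op's not-yet-seen input tensor nodes, then the
        # op node itself, then its not-yet-seen output tensor nodes.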

        def _contains(nodes, target_node):
            for node in nodes:
                if _node_id(node) == _node_id(target_node):
                    return True
            return False

        serial_ordered_tensor_nodes = []
        serial_ordered_op_nodes = []
        all_nodes = []
        for graph in self._serial_graph.all_sub_graphs():
            for node in graph.all_nodes():
                all_nodes.append(node)
        for node in all_nodes:
            if node.is_var() and node.var() is not None:
                serial_ordered_tensor_nodes.append(node)
            if node.is_op() and node.op() is not None:
                serial_ordered_op_nodes.append(node)
        serial_ordered_tensor_nodes.sort(
            key=lambda node: node.node.original_desc_id())
        serial_ordered_op_nodes.sort(
            key=lambda node: node.node.original_desc_id())
        num_nodes_before = len(serial_ordered_tensor_nodes) + len(
            serial_ordered_op_nodes)

        new_serial_ordered_tensor_nodes = []
        new_serial_ordered_op_nodes = []
        new_serial_ordered_nodes = []
        for op_node in serial_ordered_op_nodes:
            tensor_nodes = []
            for tensor_node in op_node.inputs:
                if tensor_node.is_var() \
                    and tensor_node.var() is not None \
                    and not _contains(new_serial_ordered_nodes, tensor_node):
                    tensor_nodes.append(tensor_node)
                    new_serial_ordered_tensor_nodes.append(tensor_node)
            tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
            new_serial_ordered_nodes.extend(tensor_nodes)
            new_serial_ordered_nodes.append(op_node)
            new_serial_ordered_op_nodes.append(op_node)
            tensor_nodes = []
            for tensor_node in op_node.outputs:
                if tensor_node.is_var() \
                    and tensor_node.var() is not None \
                    and not _contains(new_serial_ordered_nodes, tensor_node):
                    tensor_nodes.append(tensor_node)
                    new_serial_ordered_tensor_nodes.append(tensor_node)
            tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
            new_serial_ordered_nodes.extend(tensor_nodes)
        new_serial_ordered_tensor_nodes.sort(
            key=lambda node: node.node.original_desc_id())
        new_serial_ordered_op_nodes.sort(
            key=lambda node: node.node.original_desc_id())
        self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes
        self._serial_ordered_op_nodes = new_serial_ordered_op_nodes
        self._serial_ordered_nodes = new_serial_ordered_nodes
        assert len(self._serial_ordered_nodes) == len(
            self._serial_ordered_tensor_nodes) + len(
                self._serial_ordered_op_nodes)
        self._serial_orphan_tensor_nodes = []
        for tensor_node in serial_ordered_tensor_nodes:
            if not _contains(self._serial_ordered_tensor_nodes, tensor_node):
                self._serial_orphan_tensor_nodes.append(tensor_node)
        if len(self._serial_ordered_nodes) != num_nodes_before:
            print(
                "WARNING: there are some orphan tensors or ops which are not used in the execution."
            )

    def _init_dist_attr_for_graph(self):
        # Convert program to graph and initialize the distributed attributes
        self._order_nodes_by_program_order()
        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                dist_tensor = None
                tensor_id = node.node.original_desc_id()
                for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items(
                ):
                    if tensor_id == cur_tensor_id \
                        or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
                        dist_tensor = cur_dist_tensor
                        self._node_id_to_tensor_id[_node_id(
                            node)] = cur_tensor_id
                assert dist_tensor is not None, \
                    "Tensor must have a distributed tensor after the initialization for program."
                serial_tensor_node_id = _node_id(node)
                new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
                                                    dist_tensor.dist_attr)
                self._dist_tensors_for_graph[
                    serial_tensor_node_id] = new_dist_tensor
            if node.is_op() and node.op() is not None:
                dist_op = None
                op_id = node.node.original_desc_id()
                for cur_op_id, cur_dist_op in self._dist_ops_for_program.items(
                ):
                    if op_id == cur_op_id \
                        or op_id == cur_dist_op.serial_op.desc.original_id():
                        dist_op = cur_dist_op
                        self._node_id_to_op_id[_node_id(node)] = cur_op_id
                assert dist_op is not None, \
                    "Operator must have a distributed operator after the initialization for program."
                serial_op_node_id = _node_id(node)
                new_dist_op = DistributedOperator(dist_op.serial_op,
                                                  dist_op.dist_attr)
                self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

    def clear_dist_info_for_program(self):
        self._dist_tensors_for_program.clear()
        self._dist_ops_for_program.clear()

    def clear_dist_info_for_graph(self):
        self._dist_tensors_for_graph.clear()
        self._dist_ops_for_graph.clear()

    def copy_dist_attr_from_program_to_graph(self):
        for node in self.serial_ordered_nodes:
            if node.is_var() and node.var() is not None:
                dist_tensor = None
                tensor_id = node.node.original_desc_id()
                for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items(
                ):
                    if tensor_id == cur_tensor_id \
                        or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id():
                        dist_tensor = cur_dist_tensor
                assert dist_tensor is not None, \
                    "Tensor must have a distributed tensor after the initialization for program."
                serial_tensor_node_id = _node_id(node)
                new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
                                                    dist_tensor.dist_attr)
                self._dist_tensors_for_graph[
                    serial_tensor_node_id] = new_dist_tensor
            if node.is_op() and node.op() is not None:
                dist_op = None
                op_id = node.node.original_desc_id()
                for cur_op_id, cur_dist_op in self._dist_ops_for_program.items(
                ):
                    if op_id == cur_op_id \
                        or op_id == cur_dist_op.serial_op.desc.original_id():
                        dist_op = cur_dist_op
                assert dist_op is not None, \
                    "Operator must have a distributed operator after the initialization for program."
                serial_op_node_id = _node_id(node)
                new_dist_op = DistributedOperator(dist_op.serial_op,
                                                  dist_op.dist_attr)
                self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

    def copy_dist_attr_from_graph_to_program(self):
        assert self._is_initialized, \
            "Both program and graph must be initialized."
        updated_tensors = {}
        # all_nodes = self._serial_graph.all_nodes()
        all_nodes = self._serial_ordered_nodes
        for node in all_nodes:
            if node.is_var() and node.var() is not None:
                tensor_id = self._node_id_to_tensor_id[_node_id(node)]
                updated = updated_tensors.get(tensor_id, False)
                # If a var has multiple var nodes in the graph, only use the first one for now
                if not updated:
                    tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph(
                        node)
                    dist_tensor_for_program = self._dist_tensors_for_program[
                        tensor_id]
                    dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph
                    updated_tensors[tensor_id] = True
            if node.is_op() and node.op() is not None:
                op_id = self._node_id_to_op_id[_node_id(node)]
                op_dist_attr_for_graph = self.get_op_dist_attr_for_graph(node)
                dist_op_for_program = self._dist_ops_for_program[op_id]
                dist_op_for_program.dist_attr = op_dist_attr_for_graph
        # TODO: the completion algorithm will skip orphan tensors,
        # here we just set their process_mesh to the first one.
        for orphan_node in self._serial_orphan_tensor_nodes:
            serial_tensor_id = orphan_node.var().id()
            dist_tensor = self._dist_tensors_for_program.get(
                serial_tensor_id, None)
            if dist_tensor:
                dist_tensor.dist_attr.process_mesh = self._process_meshes[0]
            else:
                serial_tensor_id = orphan_node.var().original_id()
                dist_tensor = self._dist_tensors_for_program.get(
                    serial_tensor_id, None)
                if dist_tensor:
                    dist_tensor.dist_attr.process_mesh = self._process_meshes[0]

    def amend_dist_attr_for_program(self):
        for dist_tensor in self._dist_tensors_for_program.values():
            serial_tensor = dist_tensor.serial_tensor
            dist_attr = dist_tensor.dist_attr
            if serial_tensor.type == core.VarDesc.VarType.READER \
                or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
                or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES:
                tensor_shape = []
            else:
                tensor_shape = serial_tensor.shape
            dims_mapping = dist_attr.dims_mapping
            process_mesh_shape = dist_attr.process_mesh.topology
            process_mesh_processes = dist_attr.process_mesh.processes
            # If the dimension of tensor is less than the sharding dimension of process mesh,
            # we just amend the dimension mapping to -1. (Is this really OK?)
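            # For example (illustrative numbers): a tensor of shape [2, 8]
            # with dims_mapping [0, -1] on a mesh of topology [4, 2] would
            # shard dim 0 (size 2) across 4 processes, so that mapping is
            # reset to -1 (replicated).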
            for i in range(len(tensor_shape)):
                if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
                    and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
                    dims_mapping[i] = -1
                if dims_mapping[i] != -1 and len(process_mesh_processes) == 1:
                    dims_mapping[i] = -1

        for dist_op in self._dist_ops_for_program.values():
            serial_op = dist_op.serial_op
            dist_attr = dist_op.dist_attr
            process_mesh_shape = dist_attr.process_mesh.topology
            process_mesh_processes = dist_attr.process_mesh.processes
            for arg_name in serial_op.input_arg_names:
                if dist_op.get_serial_input(arg_name) is None:
                    tensor_shape = []
                else:
                    if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \
                        or dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
                        or dist_op.serial_op.type == "create_py_reader":
                        tensor_shape = []
                    else:
                        tensor_shape = dist_op.get_serial_input(arg_name).shape
                dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                # If the dimension of tensor is less than the sharding dimension of process mesh,
                # we just amend the dimension mapping to -1. (Is this really OK?)
                for i in range(len(tensor_shape)):
                    if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
                        and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
                        dims_mapping[i] = -1
                    if dims_mapping[i] != -1 and len(
                            process_mesh_processes) == 1:
                        dims_mapping[i] = -1
            for arg_name in serial_op.output_arg_names:
                if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \
                    or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \
                    or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.STEP_SCOPES:
                    tensor_shape = []
                else:
                    tensor_shape = dist_op.get_serial_output(arg_name).shape
                dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                # If the dimension of tensor is less than the sharding dimension of process mesh,
                # we just amend the dimension mapping to -1. (Is this really OK?)
                for i in range(len(tensor_shape)):
                    if dims_mapping[i] != -1 and tensor_shape[i] > 0 \
                        and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]:
                        dims_mapping[i] = -1
                    if dims_mapping[i] != -1 and len(
                            process_mesh_processes) == 1:
                        dims_mapping[i] = -1
            if len(process_mesh_processes) == 1:
                dist_op.dist_attr.impl_type = "default"
                dist_op.dist_attr.impl_idx = 0

    def validate_dist_attr_for_program(self):
        assert self._is_initialized, \
            "Program must be initialized before validating its distributed attributes"
        for block in self.serial_main_program.blocks:
            for tensor in block.vars.values():
                dist_tensor = self.get_dist_tensor_for_program(tensor)
                assert dist_tensor is not None, \
                    "Tensor {} does not have a distributed attribute.".format(
                        tensor.name)
                if (dist_tensor
                        is not None) and (not dist_tensor.validate_dist_attr()):
                    assert False, "Tensor {} (id: {}, original_id: {}) has wrong distributed attributes {}.".format(
                        dist_tensor.serial_tensor.name,
                        dist_tensor.serial_tensor.desc.id(),
                        dist_tensor.serial_tensor.desc.original_id(),
                        dist_tensor.dist_attr)
            for op in block.ops:
                dist_op = self.get_dist_op_for_program(op)
                assert dist_op is not None, \
                    "Operator {} does not have a distributed attribute.".format(
                        op.type)
                if (dist_op is not None) and (not dist_op.validate_dist_attr()):
                    assert False, "Operator {} (id: {}, original_id: {}) has wrong distributed attributes {}.".format(
                        dist_op.serial_op.type, dist_op.serial_op.desc.id(),
                        dist_op.serial_op.desc.original_id(), dist_op.dist_attr)
        return True

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
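            # Programs, graphs and node lists are intentionally shared rather
            # than deep-copied: the copy is expected to reference the same
            # serial artifacts, and deep-copying them would be very expensive.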
            if k in [
                "_original_serial_main_program", "_original_serial_startup_program", \
                "_serial_main_program", "_serial_startup_program", "_serial_graph", \
                "_dist_main_programs", "_dist_startup_programs", \
                "_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \
                "_serial_ordered_op_nodes", "_original_serial_loss", \
                "_original_serial_feed_vars", "_original_serial_fetch_vars", \
                "_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_lr_optimizer", \
                "_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \
                "_pass_context"]:
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))

        # update dist tensor's dist_context
        for key in result._dist_tensors_for_program.keys():
            result._dist_tensors_for_program[key]._dist_context = result
        return result


class DistributedOperatorContext:
    """
    DistributedOperatorContext is used to create a dist op desc in Program.
    Each time a new dist op is created, the context should be updated accordingly.
    """

    def __init__(self):
        self._dst_main_program = None
        self._main_block = None
        self._dst_startup_program = None
        self._startup_block = None
        self._cur_src_op = None
        self._cur_dist_attr = None
        self.grad_op_id_to_op_id = {}
        self.grad_var_to_var = defaultdict(dict)
        self._work_block = None
        self.already_init_sync_vars = set()
        self.varname_mapping = None
        self.rank_id = None
        # NOTE: Support correct parallelism for high-order differential models.
        # By default _exceed_backward_init_op is False, which means we are in
        # the forward phase; once it becomes True, we are in the backward phase.
        # The final solution should be to revise the high-order differential
        # logic for these two phases in the future.
        self._exceed_backward_init_op = False

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in [
                    "_dst_main_program", "_dst_startup_program", "_cur_src_op",
                    "_work_block", "_main_block", "_startup_block"
            ]:
                setattr(result, k, v)
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

    @property
    def dst_main_program(self):
        return self._dst_main_program

    @dst_main_program.setter
    def dst_main_program(self, prog):
        self._dst_main_program = prog
        self._main_block = prog.blocks[0]

    @property
    def main_block(self):
        return self._main_block

    @property
    def dst_startup_program(self):
        return self._dst_startup_program

    @dst_startup_program.setter
    def dst_startup_program(self, prog):
        self._dst_startup_program = prog
        self._startup_block = prog.blocks[0]

    @property
    def startup_block(self):
        return self._startup_block

    @property
    def work_block(self):
        assert self._work_block is not None
        return self._work_block

    @work_block.setter
    def work_block(self, block):
        assert block is not None
        self._work_block = block

    @property
    def cur_src_op(self):
        assert self._cur_src_op is not None
        return self._cur_src_op

    def in_backward_phase(self):
        return self._exceed_backward_init_op

    def prepare_context(self, src_op):

        self._cur_src_op = src_op

        if is_loss_grad_op(src_op):
            self._exceed_backward_init_op = True

        # build input varname mapping
        kinputs = {}
        for input_name in src_op.desc.input_names():
            varnames = []
            for varname in src_op.desc.input(input_name):
                assert varname in self.varname_mapping
                varnames.append(self.varname_mapping[varname])
            kinputs[input_name] = varnames

        # build output varname mapping
        koutputs = {}
        for output_name in src_op.desc.output_names():
            varnames = []
            for varname in src_op.desc.output(output_name):
                assert varname in self.varname_mapping
                varnames.append(self.varname_mapping[varname])
            koutputs[output_name] = varnames

        return kinputs, koutputs


class BlockState(object):
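    """
    BlockState records how a program's blocks split into forward blocks and
    the backward blocks appended by autodiff, keeping a backward-to-forward
    index map (including for control-flow sub-blocks).
    """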

    def __init__(self):
        self.nblock = 0
        self.forward_indices = []
        self.backward_indices = []
        self.backward_to_forward_index_map = {}

    def parse_forward_blocks(self, program):

        while program.current_block_idx != 0:
            program._rollback()

        assert program.current_block_idx == 0

        for idx, block in enumerate(program.blocks):

            assert idx == block.idx, "index doesn't match"
            assert block.forward_block_idx == -1, "forward_block_idx of forward block [{}] should be -1 but got [{}]".format(
                idx, block.forward_block_idx)
            self.forward_indices.append(idx)
            self.nblock += 1

        assert self.nblock >= 1

    def parse_backward_blocks(self, program):

        assert 0 in self.forward_indices, "forward block indices are {}".format(
            self.forward_indices)
        self.backward_to_forward_index_map[0] = 0

        for idx, block in enumerate(program.blocks):

            if idx < len(self.forward_indices):
                continue

            assert idx == block.idx, "index doesn't match"
            assert block.forward_block_idx in self.forward_indices
            self.backward_indices.append(idx)
            self.backward_to_forward_index_map[idx] = block.forward_block_idx
            self.nblock += 1

        assert self.nblock == len(program.blocks)