# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from copy import deepcopy
import time

from paddle.fluid import core
from paddle.fluid import framework

from .utils import print_program_with_dist_attr
from .operators import find_best_compatible_distributed_operator_impl
from .dist_context import get_default_distributed_context
from .dist_tensor import DistributedTensor
from .dist_op import DistributedOperator
from .dist_attribute import TensorDistributedAttribute
from .dist_attribute import OperatorDistributedAttribute
from paddle.distributed.fleet.meta_optimizers.common import OpRole


def compute_compatible_process_mesh(process_mesh_list):
    """Compute the compatible process mesh given a list of process meshes.

    Two meshes are compatible when:

    * either one is ``None`` (a wildcard);
    * they are equal;
    * they cover the same processes, in which case the mesh with more
      topology dimensions wins; or
    * one's process set is a subset of the other's, in which case the
      superset wins.

    Args:
        process_mesh_list: sequence of process meshes; ``None`` entries are
            wildcards.

    Returns:
        A deep copy of the mesh compatible with every entry, or ``None`` when
        the list is empty, every entry is ``None``, or two entries conflict.
    """
    if not process_mesh_list:
        return None

    def _compute_compatible_process_mesh_two(pm1, pm2):
        # Returns (is_compatible, merged_mesh) for a pair of meshes.
        if pm1 is None:
            return True, pm2
        if pm2 is None:
            return True, pm1
        if pm1 == pm2:
            return True, pm1
        if pm1.processes == pm2.processes:
            # Same processes: prefer the mesh with more topology dimensions.
            if len(pm1.topology) >= len(pm2.topology):
                return True, pm1
            return True, pm2
        process_set1 = set(pm1.processes)
        process_set2 = set(pm2.processes)
        # A mesh over a subset of processes is absorbed by its superset.
        if process_set1.issubset(process_set2):
            return True, pm2
        if process_set2.issubset(process_set1):
            return True, pm1
        return False, None

    compatible_result = None
    for process_mesh in process_mesh_list:
        compatible, compatible_result = _compute_compatible_process_mesh_two(
            compatible_result, process_mesh)
        if not compatible:
            return None
    # Return a copy so the result never aliases one of the input meshes.
    return copy.deepcopy(compatible_result)


def compute_compatible_dim_mapping(dim_mapping_list):
    """Compute the compatible dim mapping given a list of dim mappings.

    ``-1`` is a wildcard that is compatible with any value. All other values
    must be equal to each other.

    Args:
        dim_mapping_list: sequence of integer dim mappings.

    Returns:
        The merged dim mapping (``-1`` if every entry is ``-1``), or ``None``
        when the list is empty or two non-wildcard entries differ.
    """
    if not dim_mapping_list:
        return None

    def _compute_compatible_dim_mapping_two(dm1, dm2):
        # Returns (is_compatible, merged_dim_mapping) for a pair of values.
        if dm1 == -1:
            return True, dm2
        if dm2 == -1:
            return True, dm1
        if dm1 == dm2:
            return True, dm1
        return False, None

    compatible_result = -1
    for mapping in dim_mapping_list:
        compatible, compatible_result = _compute_compatible_dim_mapping_two(
            compatible_result, mapping)
        if not compatible:
            return None
    return compatible_result


def compute_compatible_dims_mapping(dims_mapping_list):
    """Compute the compatible dims mapping given a list of dims mappings.

    Each dims mapping is a list with one dim mapping per tensor dimension.
    The mappings are merged position by position with
    ``compute_compatible_dim_mapping``.

    Args:
        dims_mapping_list: sequence of dims mappings (lists of ints).

    Returns:
        The merged dims mapping, or ``None`` when:

        * the list is empty;
        * any entry is ``None``;
        * the entries have different lengths; or
        * some position cannot be merged.
    """
    if not dims_mapping_list:
        return None
    # Check for None before calling len(), so a None first entry returns None
    # instead of raising TypeError.
    if any(dims_mapping is None for dims_mapping in dims_mapping_list):
        return None
    length = len(dims_mapping_list[0])
    if any(len(dims_mapping) != length for dims_mapping in dims_mapping_list):
        return None
    compatible_result = []
    for dim_mappings in zip(*dims_mapping_list):
        compatible_dim_mapping = compute_compatible_dim_mapping(
            list(dim_mappings))
        if compatible_dim_mapping is None:
            return None
        compatible_result.append(compatible_dim_mapping)
    return compatible_result


class Completer:
    """Completes the distributed attributes of a partially annotated serial
    program.

    The forward part is completed by propagating process meshes and dims
    mappings over the graph held by the distributed context. Backward ops
    copy their attributes from the matching forward ops or vars. Optimizer
    ops copy theirs from their parameters.
    """

    # Reader-related op types; they are skipped during dims-mapping
    # propagation.
    _READER_OP_TYPES = ("create_py_reader", "create_double_buffer_reader",
                        "read")

    def __init__(self, dist_context):
        assert dist_context is not None
        self._dist_context = dist_context

    def _update_tensor_node_dims_mapping(self, tensor_node, fwd=True):
        """Merge a tensor node's dims mapping with its neighbouring op nodes.

        Args:
            tensor_node: graph node expected to wrap a variable.
            fwd: if True, merge with the output dims mappings of the producer
                ops (``tensor_node.inputs``). Otherwise merge with the input
                dims mappings of the consumer ops (``tensor_node.outputs``).

        Returns:
            True if the tensor's dims mapping changed, otherwise False. These
            nodes are never changed: non-variable nodes, reader tensors, and
            tensors whose dims mapping is annotated.
        """
        if (not tensor_node.is_var()) or (tensor_node.var() is None):
            return False
        tensor_desc = tensor_node.var()
        # Skip reader tensor
        if tensor_desc.type() == core.VarDesc.VarType.READER:
            return False
        tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
            tensor_node)
        assert tensor_dist_attr is not None
        if tensor_dist_attr.is_annotated("dims_mapping"):
            return False
        tensor_name = tensor_desc.name()
        tensor_dims_mapping = tensor_dist_attr.dims_mapping
        neighbour_op_nodes = tensor_node.inputs if fwd else tensor_node.outputs

        dims_mapping_list = []
        for op_node in neighbour_op_nodes:
            if op_node.op() is None:
                continue
            if op_node.op().type() in self._READER_OP_TYPES:
                continue
            op_dist_attr = self._dist_context.get_op_dist_attr_for_graph(
                op_node)
            # Only ops on the same process mesh constrain this tensor.
            if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
                if fwd:
                    op_dims_mapping = op_dist_attr.get_output_dims_mapping(
                        tensor_name)
                else:
                    op_dims_mapping = op_dist_attr.get_input_dims_mapping(
                        tensor_name)
                dims_mapping_list.append(op_dims_mapping)
        dims_mapping_list.append(tensor_dims_mapping)

        compatible_dims_mapping = compute_compatible_dims_mapping(
            dims_mapping_list)
        if (compatible_dims_mapping is not None) and \
                (compatible_dims_mapping != tensor_dims_mapping):
            tensor_dist_attr.dims_mapping = compatible_dims_mapping
            return True
        return False

    def _update_op_node_dims_mapping(self, op_node, fwd=True):
        """Merge an op node's dims mappings with its neighbouring tensors.

        After merging, the best compatible distributed implementation updates
        the op's dims mappings. If that implementation is auto-compatible, it
        is recorded on the op.

        Args:
            op_node: graph node expected to wrap an operator.
            fwd: if True, merge the op's input dims mappings with its input
                tensors. Otherwise merge its output dims mappings with its
                output tensors.

        Returns:
            True if any dims mapping of the op changed, otherwise False.
        """
        if (not op_node.is_op()) or (op_node.op() is None):
            return False
        # Skip reader op
        op_desc = op_node.op()
        if op_desc.type() in self._READER_OP_TYPES:
            return False
        dist_op = self._dist_context.get_dist_op_for_graph(op_node)
        op_dist_attr = dist_op.dist_attr
        if fwd:
            tensor_nodes = op_node.inputs
            is_annotated = op_dist_attr.is_annotated_input_dims_mapping
            get_dims_mapping = op_dist_attr.get_input_dims_mapping
            set_dims_mapping = op_dist_attr.set_input_dims_mapping
        else:
            tensor_nodes = op_node.outputs
            is_annotated = op_dist_attr.is_annotated_output_dims_mapping
            get_dims_mapping = op_dist_attr.get_output_dims_mapping
            set_dims_mapping = op_dist_attr.set_output_dims_mapping

        changed = False
        for tensor_node in tensor_nodes:
            tensor_desc = tensor_node.var()
            if tensor_desc is None:
                continue
            if tensor_desc.type() == core.VarDesc.VarType.READER:
                continue
            tensor_name = tensor_desc.name()
            # Annotated dims mappings are never overwritten.
            if is_annotated(tensor_name):
                continue
            tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph(
                tensor_node)
            if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh:
                tensor_dims_mapping = tensor_dist_attr.dims_mapping
                op_dims_mapping = get_dims_mapping(tensor_name)
                compatible_dims_mapping = compute_compatible_dims_mapping(
                    [op_dims_mapping, tensor_dims_mapping])
                if (compatible_dims_mapping is not None) and \
                        (compatible_dims_mapping != op_dims_mapping):
                    set_dims_mapping(tensor_name, compatible_dims_mapping)
                    changed = True

        # Find the most compatible implementation of the distributed operator
        op_dist_impl = find_best_compatible_distributed_operator_impl(
            dist_op, fwd=bool(fwd))
        assert op_dist_impl is not None, "Cannot find the dist op implementation."
        if op_dist_impl.update_dims_mapping(dist_op):
            changed = True
        if op_dist_impl.is_auto_compatible(dist_op):
            if op_dist_impl.type == "elementwise":
                op_dist_attr.impl_type = "default"
            else:
                op_dist_attr.impl_type = op_dist_impl.type
            op_dist_attr.impl_idx = op_dist_impl.idx
        return changed

    def _update_process_mesh(self):
        """Propagate process meshes along the serial node order until nothing
        changes.

        In each sweep direction, every node's process mesh is merged with the
        mesh of the nearest earlier node (in that direction) that already has
        one. Forward and backward sweeps alternate until neither changes
        anything.
        """

        def _find_nearest_node(nodes, idx):
            # Closest node before position ``idx`` that has a process mesh.
            for node in reversed(nodes[:idx]):
                node_dist_attr = self._dist_context.get_dist_attr_for_graph(
                    node)
                if node_dist_attr.process_mesh is not None:
                    return node
            return None

        total_reach_fix_point = False
        while not total_reach_fix_point:
            total_changed = False
            for is_fwd in [True, False]:
                # Build the sweep order as a list. The inner loop iterates it
                # repeatedly and _find_nearest_node indexes into it with the
                # same ``idx``. A one-shot reversed() iterator would be empty
                # on the second pass, and indexing the forward list with a
                # backward-order idx would pick the wrong neighbour.
                all_nodes = list(self._dist_context.serial_ordered_nodes)
                if not is_fwd:
                    all_nodes.reverse()
                reach_fix_point = False
                while not reach_fix_point:
                    changed = False
                    for idx, node in enumerate(all_nodes):
                        nearest_node = _find_nearest_node(all_nodes, idx)
                        if nearest_node is None:
                            continue
                        nearest_process_mesh = self._dist_context.get_dist_attr_for_graph(
                            nearest_node).process_mesh
                        cur_node_dist_attr = self._dist_context.get_dist_attr_for_graph(
                            node)
                        cur_process_mesh = cur_node_dist_attr.process_mesh
                        compatible_process_mesh = compute_compatible_process_mesh(
                            [cur_process_mesh, nearest_process_mesh])
                        if compatible_process_mesh is not None \
                                and cur_process_mesh != compatible_process_mesh:
                            cur_node_dist_attr.process_mesh = compatible_process_mesh
                            changed = True
                    if changed:
                        total_changed = True
                    else:
                        reach_fix_point = True
            total_reach_fix_point = not total_changed

    def _update_dims_mapping(self):
        """Complete the dims mapping of every graph node.

        Forward and backward sweeps over all var and op nodes alternate until
        a full round changes nothing.
        """
        reach_fix_point = False
        while not reach_fix_point:
            changed = False
            for is_fwd in [True, False]:
                all_nodes = self._dist_context.serial_ordered_nodes \
                    if is_fwd else reversed(self._dist_context.serial_ordered_nodes)
                for node in all_nodes:
                    # Explicit ifs (not `changed or ...`) so every update runs.
                    if node.is_var() and node.var() is not None:
                        if self._update_tensor_node_dims_mapping(
                                node, fwd=is_fwd):
                            changed = True
                    if node.is_op() and node.op() is not None:
                        if self._update_op_node_dims_mapping(
                                node, fwd=is_fwd):
                            changed = True
            reach_fix_point = not changed

    def complete_forward_annotation(self, serial_main_program):
        """ Complete annotation for the partial annotated serial_main_program.

        Arguments:
            serial_main_program: partial annotated serial_main_program.
        Returns:
            serial_main_program: completed annotated serial_main_program.
        """
        # Bind the program to the distributed context used for completion.
        self._dist_context.serial_program = serial_main_program

        # Initialize distributed attributes for all var and op node in serial_main_program
        self._dist_context.init_dist_attr_for_program()

        # Initialize distributed attributes for all var and op node in graph
        self._dist_context.init_dist_attr_for_graph()

        # Complete process_mesh for each node
        self._update_process_mesh()

        # Complete dims_mapping for each node
        self._update_dims_mapping()

        # Copy the corresponding distributed attribute from graph to serial_main_program
        self._dist_context.copy_dist_attr_from_graph_to_program()
        self._dist_context.clear_dist_info_for_graph()

        # Do the validation check and amend some completion
        self._dist_context.amend_dist_attr_for_program()

        self._dist_context.validate_dist_attr_for_program()

        return serial_main_program

    def complete_backward_annotation(self, serial_main_program):
        """Complete the annotation of vars and ops in the backward phase for parallel program.

        The first backward op is the ``fill_constant`` whose op_role is
        Backward|Loss. It and every later op in the global block get
        attributes derived from the forward part:

        * the loss-grad ``fill_constant`` copies the forward loss var's attrs;
        * a ``concat`` grad of a forward ``split`` reuses the split input's
          dims mapping and the split op's process mesh;
        * other grad ops (mapped through
          ``dist_op_context.grad_op_id_to_op_id``) copy dims mappings from
          their forward op;
        * ``sum`` ops that merge several gradient versions copy the attrs of
          the forward var whose gradient they produce.

        Args:
            serial_main_program: program whose backward ops are annotated.
                The annotations are stored in the distributed context.
        """

        def _is_grad_var_name(name):
            return "@GRAD" in name

        def _get_forward_varname_from_grad_varname(grad_var_name):
            assert _is_grad_var_name(
                grad_var_name), "[{}] is not a grad varnme.".format(
                    grad_var_name)
            return grad_var_name[:grad_var_name.find("@GRAD")]

        def _get_op_by_id(ops, op_id):
            for op in ops:
                if op.desc.id() == op_id:
                    return op
            return None

        loss_grad_role = int(
            int(core.op_proto_and_checker_maker.OpRole.Backward) | int(
                core.op_proto_and_checker_maker.OpRole.Loss))
        first_backward_op_idx = -1
        for idx, op in enumerate(serial_main_program.global_block().ops):
            if int(op.attr('op_role')) == loss_grad_role:
                assert op.type == "fill_constant"
                first_backward_op_idx = idx
                break

        assert first_backward_op_idx >= 0, "No backward procedure found in this program."

        ops = list(serial_main_program.global_block().ops)
        block_vars = serial_main_program.global_block().vars
        dist_op_context = self._dist_context.dist_op_context

        # complete the initial grad loss op
        loss_grad_op = ops[first_backward_op_idx]
        assert len(
            loss_grad_op.input_arg_names
        ) == 0, "first backward op should have no input, but got [{}]".format(
            len(loss_grad_op.input_arg_names))
        assert len(
            loss_grad_op.output_arg_names
        ) == 1, "first backward op should has only ONE output, but got [{}]".format(
            len(loss_grad_op.output_arg_names))

        grad_var = block_vars[loss_grad_op.output_arg_names[0]]
        forward_var_name = _get_forward_varname_from_grad_varname(
            grad_var.name)
        forward_var = block_vars[forward_var_name]

        # TODO complete other attribte for grad var
        forward_var_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
            forward_var)
        process_mesh = forward_var_dist_attr.process_mesh
        dims_mapping = forward_var_dist_attr.dims_mapping
        tensor_dist_attr = TensorDistributedAttribute()
        tensor_dist_attr.dims_mapping = dims_mapping
        tensor_dist_attr.process_mesh = process_mesh
        self._dist_context.set_tensor_dist_attr_for_program(grad_var,
                                                            tensor_dist_attr)

        op_dist_attr = OperatorDistributedAttribute()
        op_dist_attr.process_mesh = process_mesh
        op_dist_attr.set_output_dims_mapping(grad_var.name, dims_mapping)
        self._dist_context.set_op_dist_attr_for_program(loss_grad_op,
                                                        op_dist_attr)

        for grad_op in ops[first_backward_op_idx + 1:]:
            # complete the annotation of grad op (xxx_grad op or sum op)
            # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
            if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
                # TODO support the case where one forward op corresponding to multiple xxx_grad op
                forward_op = _get_op_by_id(
                    ops[:first_backward_op_idx],
                    dist_op_context.grad_op_id_to_op_id[grad_op.desc.id()])
                assert forward_op is not None

                if grad_op.type == "concat" and forward_op.type == "split":
                    forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                        forward_op)
                    output_var = block_vars[grad_op.desc.output('Out')[0]]
                    split_input_var_name = forward_op.input("X")[0]
                    ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
                        split_input_var_name)
                    ref_mesh = forward_op_dist_attr.process_mesh

                    grad_op_dist_attr = OperatorDistributedAttribute()
                    for input_name in grad_op.input_arg_names:
                        grad_op_dist_attr.set_input_dims_mapping(
                            input_name, ref_dims_mapping)

                    output_var_dist_attr = TensorDistributedAttribute()
                    output_var_dist_attr.dims_mapping = ref_dims_mapping
                    output_var_dist_attr.process_mesh = ref_mesh
                    self._dist_context.set_tensor_dist_attr_for_program(
                        output_var, output_var_dist_attr)

                    grad_op_dist_attr.set_output_dims_mapping(output_var.name,
                                                              ref_dims_mapping)
                    grad_op_dist_attr.process_mesh = ref_mesh
                    self._dist_context.set_op_dist_attr_for_program(
                        grad_op, grad_op_dist_attr)
                    continue

                # op dist attr
                forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
                    forward_op)
                forward_op_process_mesh = forward_op_dist_attr.process_mesh
                grad_op_dist_attr = OperatorDistributedAttribute()
                grad_op_dist_attr.process_mesh = forward_op_process_mesh

                # var
                for input_name in grad_op.input_arg_names:
                    input_var = block_vars[input_name]
                    if _is_grad_var_name(input_name):
                        # Gradient of a forward output: reuse that output's mapping.
                        forward_name = _get_forward_varname_from_grad_varname(
                            input_name)
                        ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
                            forward_name)
                    elif forward_op_dist_attr.get_input_dims_mapping(
                            input_name):
                        ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
                            input_name)
                    else:
                        ref_dims_mapping = forward_op_dist_attr.get_output_dims_mapping(
                            input_name)

                    assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format(
                        input_var.name)
                    grad_op_dist_attr.set_input_dims_mapping(input_name,
                                                             ref_dims_mapping)

                for output_name in grad_op.desc.output_names():
                    assert len(grad_op.desc.output(output_name)) in [0, 1]
                    if _is_grad_var_name(output_name):
                        input_name = _get_forward_varname_from_grad_varname(
                            output_name)
                    else:
                        assert grad_op.type in [
                            "cast", "c_identity", "c_allreduce_sum"
                        ]
                        input_name = "X"
                    assert input_name in forward_op.desc.input_names(
                    ), "var [{}] in op [{}]'s output but could not find [{}] in its forward op".format(
                        output_name, grad_op.type, input_name)
                    if len(grad_op.desc.output(output_name)) == 1:
                        # tensor dist attr
                        output_var = block_vars[grad_op.desc.output(
                            output_name)[0]]
                        forward_name = _get_forward_varname_from_grad_varname(
                            output_var.name)
                        ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
                            forward_name)

                        output_var_dist_attr = TensorDistributedAttribute()
                        output_var_dist_attr.dims_mapping = ref_dims_mapping
                        output_var_dist_attr.process_mesh = forward_op_process_mesh
                        self._dist_context.set_tensor_dist_attr_for_program(
                            output_var, output_var_dist_attr)

                        grad_op_dist_attr.set_output_dims_mapping(
                            output_var.name, ref_dims_mapping)

                self._dist_context.set_op_dist_attr_for_program(
                    grad_op, grad_op_dist_attr)

            else:
                # Only the sum op that merges multiple versions of a grad has
                # no corresponding entry in grad_op_id_to_op_id.
                assert grad_op.type == "sum", "got unexpect op [{}]".format(
                    str(grad_op.type))
                assert all(map(_is_grad_var_name, grad_op.input_arg_names))
                assert len(grad_op.output_arg_names) == 1

                ref_forward_var_name = _get_forward_varname_from_grad_varname(
                    grad_op.output_arg_names[0])
                forward_var = block_vars[ref_forward_var_name]
                forward_var_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                    forward_var)
                ref_forward_var_dims_mapping = forward_var_dist_attr.dims_mapping
                ref_forward_var_process_mesh = forward_var_dist_attr.process_mesh

                # output
                tensor_dist_attr = TensorDistributedAttribute()
                tensor_dist_attr.dims_mapping = ref_forward_var_dims_mapping
                tensor_dist_attr.process_mesh = ref_forward_var_process_mesh
                self._dist_context.set_tensor_dist_attr_for_program(
                    block_vars[grad_op.output_arg_names[0]], tensor_dist_attr)

                # op
                grad_op_dist_attr = OperatorDistributedAttribute()
                grad_op_dist_attr.process_mesh = ref_forward_var_process_mesh
                for var_name in grad_op.input_arg_names:
                    assert _get_forward_varname_from_grad_varname(
                        var_name) == ref_forward_var_name
                    grad_op_dist_attr.set_input_dims_mapping(
                        var_name, ref_forward_var_dims_mapping)

                grad_op_dist_attr.set_output_dims_mapping(
                    grad_op.output_arg_names[0], ref_forward_var_dims_mapping)
                self._dist_context.set_op_dist_attr_for_program(
                    grad_op, grad_op_dist_attr)

    def complete_update_annotation(self, serial_main_program):
        """Complete the annotation of vars and ops in the update phase for parallel program.

        This applies to every optimizer op (``op_role == OpRole.Optimize``)
        that has both ``Param`` and ``Grad`` inputs:

        * the op gets the parameter's process mesh;
        * the grad, the param and the ``Moment*`` states reuse the parameter's
          dims mapping;
        * the learning rate and the ``Beta1Pow``/``Beta2Pow`` states get
          ``[-1]``.

        Args:
            serial_main_program: program whose optimizer ops are annotated.
                The annotations are stored in the distributed context.
        """
        ops = list(serial_main_program.global_block().ops)
        block_vars = serial_main_program.global_block().vars
        learning_rate_completed = False

        for op in ops:
            # complete the annotation of the optimizer op.
            # TODO to add attribute for moment var
            if int(op.attr('op_role')) != int(OpRole.Optimize):
                continue
            if not ("Grad" in op.input_names and "Param" in op.input_names):
                continue

            assert len(op.input("Param")) == 1, "Only support one-to-one now."
            assert len(op.input("Grad")) == 1, "Only support one-to-one now."
            param = block_vars[op.input("Param")[0]]
            grad_var = block_vars[op.input("Grad")[0]]

            param_dist_attr = self._dist_context.get_tensor_dist_attr_for_program(
                param)
            assert param_dist_attr is not None
            ref_process_mesh = param_dist_attr.process_mesh
            assert ref_process_mesh is not None
            ref_dims_mapping = param_dist_attr.dims_mapping
            assert ref_dims_mapping is not None

            op_dist_attr = OperatorDistributedAttribute()
            op_dist_attr.process_mesh = ref_process_mesh
            op_dist_attr.set_input_dims_mapping(grad_var.name,
                                                ref_dims_mapping)
            op_dist_attr.set_input_dims_mapping(param.name, ref_dims_mapping)
            op_dist_attr.set_output_dims_mapping(param.name, ref_dims_mapping)
            learning_var = block_vars[op.input("LearningRate")[0]]
            op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
            op_dist_attr.set_output_dims_mapping(learning_var.name, [-1])

            # The learning-rate tensor attr is set only once, from the first
            # optimizer op encountered.
            if not learning_rate_completed:
                learning_rate_completed = True
                var_dist_attr = TensorDistributedAttribute()
                var_dist_attr.process_mesh = ref_process_mesh
                var_dist_attr.dims_mapping = [-1]
                self._dist_context.set_tensor_dist_attr_for_program(
                    learning_var, var_dist_attr)

            for input_name in op.desc.input_names():

                if input_name in [
                        'Param', 'Grad', 'LearningRate', "SkipUpdate",
                        "Beta1Tensor", "Beta2Tensor", "EpsilonTensor",
                        "MasterParam"
                ]:
                    continue

                assert len(op.desc.input(input_name)) == 1
                input_var = block_vars[op.desc.input(input_name)[0]]
                input_var_attr = TensorDistributedAttribute()

                if "Beta1Pow" in input_name or "Beta2Pow" in input_name:
                    input_var_attr.dims_mapping = [-1]
                    op_dist_attr.set_input_dims_mapping(input_var.name, [-1])
                    op_dist_attr.set_output_dims_mapping(input_var.name, [-1])
                else:
                    assert "Moment" in input_name
                    input_var_attr.dims_mapping = ref_dims_mapping
                    op_dist_attr.set_input_dims_mapping(input_var.name,
                                                        ref_dims_mapping)
                    op_dist_attr.set_output_dims_mapping(input_var.name,
                                                         ref_dims_mapping)

                input_var_attr.process_mesh = ref_process_mesh
                self._dist_context.set_tensor_dist_attr_for_program(
                    input_var, input_var_attr)

            self._dist_context.set_op_dist_attr_for_program(op, op_dist_attr)