rule_based_tuner.py 86.6 KB
Newer Older
C
caozhou 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
import copy
import logging
17
import math
18
import os
19 20 21
import pickle
import sys
import time
22
from abc import abstractmethod
23
from collections import OrderedDict
24 25 26
from functools import reduce

import numpy as np
27 28

import paddle
29
from paddle.distributed.auto_parallel.cluster_v2 import DeviceMesh
30
from paddle.distributed.auto_parallel.completion import Completer
31
from paddle.distributed.auto_parallel.cost import CostEstimator
32 33 34 35 36 37
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistAttr,
    TensorDistAttr,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.dist_tensor import DistributedTensor
38 39 40 41 42 43
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.utils import (
    is_gradient_clip_op,
    print_program_with_dist_attr,
)
from paddle.distributed.fleet.meta_optimizers.common import OpRole
44 45 46 47 48
from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import Parameter, unique_name

from ...utils.log_utils import get_logger
49 50 51 52 53 54 55 56
from ..graph import Graph

_PATTERNS = {}


def register_pattern(cls):
    """Register pattern for rule-based tuner.

    Instantiates ``cls`` once, stores the instance in the module-level
    ``_PATTERNS`` registry under ``instance.name``, then rebuilds the
    registry ordered by ``attrs["sharded_tensors"]`` (largest first).
    When a tensor can be matched by several patterns, its dist attr is set
    by the first one, so patterns sharding more tensors take precedence;
    patterns with equal counts keep their registration order.

    Returns ``cls`` unchanged so this works as a class decorator.
    """
    global _PATTERNS

    instance = cls()
    _PATTERNS[instance.name] = instance

    def _sharded_count(entry):
        # entry is a (name, pattern_instance) pair
        return entry[1].attrs["sharded_tensors"]

    # sorted() is stable even with reverse=True, so ties keep insertion order
    ordered = sorted(_PATTERNS.items(), key=_sharded_count, reverse=True)
    _PATTERNS = dict(ordered)

    return cls


74
class BasePattern(Graph):
    """
    Base class of every rule-based tuner pattern.

    A pattern is a small Graph describing an op/var sub-structure to look
    for. Beyond what Graph provides, subclasses put two entries into
    ``self.attrs`` from ``build``:

    - ``shard_spec``: for each parallelism key (e.g. "dp_mp", "mp"), a map
      from pattern node id to a list with one entry per tensor dim
      (values such as -1, 0, 1; presumably mesh-dim indices with -1 meaning
      "not sharded" — confirm against the consumer).
    - ``sharded_tensors``: how many tensors the pattern shards; the pattern
      registry is ordered by this count.
    """

    _name = "base"

    def __init__(self):
        """Initialize the underlying graph, then let the subclass fill it."""
        super().__init__()
        # subclasses add their nodes, edges and attrs here
        self.build()

    @property
    def name(self):
        """Pattern name, taken from the class-level ``_name``."""
        return type(self)._name

    @abstractmethod
    def build(self):
        """Add the pattern's nodes, edges and attrs; provided by subclasses."""


@register_pattern
class QKVPattern(BasePattern):
    """The QKV pattern defined by GPT model in PaddleFleetX.

    One input var (node 0) is the X operand of three matmul_v2 ops
    (nodes 4-6) whose Y operands are the 2-D q/k/v weights (nodes 1-3);
    each matmul writes its own output var (nodes 7-9).
    """

    name = "qkv"

    def __init__(self):
        super().__init__()

    def build(self):
        # input shared by the three projections
        query = self.add_node(0, type="var")

        # q, k, v weights, then the q, k, v matmul_v2 ops
        weights = [
            self.add_node(node_id, dim=2, type="param") for node_id in (1, 2, 3)
        ]
        matmuls = [
            self.add_node(node_id, type="matmul_v2") for node_id in (4, 5, 6)
        ]

        # every matmul reads the shared input as X ...
        for matmul in matmuls:
            self.add_edge(query.id, matmul.id, input_name="X")
        # ... and its own weight as Y
        for weight, matmul in zip(weights, matmuls):
            self.add_edge(weight.id, matmul.id, input_name="Y")

        # q, k, v outputs, one per matmul
        outputs = [self.add_node(node_id, type="var") for node_id in (7, 8, 9)]
        for matmul, output in zip(matmuls, outputs):
            self.add_edge(matmul.id, output.id, output_name="Out")

        # sharding of the input (node 0) and the weights (nodes 1-3)
        self.attrs["shard_spec"] = {
            "dp_mp": {0: [0, -1, -1], 1: [-1, 1], 2: [-1, 1], 3: [-1, 1]},
            "mp_dp": {0: [1, -1, -1], 1: [-1, 0], 2: [-1, 0], 3: [-1, 0]},
            "mp": {0: [-1, -1, -1], 1: [-1, 0], 2: [-1, 0], 3: [-1, 0]},
            "dp": {0: [0, -1, -1], 1: [-1, -1], 2: [-1, -1], 3: [-1, -1]},
        }
        # number of tensors this pattern shards
        self.attrs["sharded_tensors"] = 4


@register_pattern
class RowMatmulPattern(BasePattern):
    """Row matmul pattern defined by GPT model in PaddleFleetX.

    A reshape2 (node 1) turns var node 0 into var node 2, which is the X
    operand of a matmul_v2 (node 4) whose Y operand is a 2-D weight
    (node 3); the matmul writes var node 5.
    """

    name = "row_matmul"

    def __init__(self):
        super().__init__()

    def build(self):
        # reshape2: source var -> reshaped var
        source = self.add_node(0, type="var")
        reshape_op = self.add_node(1, type="reshape2")
        self.add_edge(source.id, reshape_op.id, input_name="X")
        reshaped = self.add_node(2, type="var")
        self.add_edge(reshape_op.id, reshaped.id, output_name="Out")

        # matmul_v2: reshaped var (X) times weight (Y) -> product var
        weight = self.add_node(3, dim=2, type="param")
        matmul_op = self.add_node(4, type="matmul_v2")
        self.add_edge(reshaped.id, matmul_op.id, input_name="X")
        self.add_edge(weight.id, matmul_op.id, input_name="Y")
        product = self.add_node(5, type="var")
        self.add_edge(matmul_op.id, product.id, output_name="Out")

        # only the weight (node 3) has a sharding entry
        self.attrs["shard_spec"] = {
            "dp_mp": {3: [1, -1]},
            "mp_dp": {3: [0, -1]},
            "mp": {3: [0, -1]},
            "dp": {3: [-1, -1]},
        }
        # number of tensors this pattern shards
        self.attrs["sharded_tensors"] = 1


@register_pattern
class FFNPattrern(BasePattern):
    """FFN pattern defined by GPT model in PaddleFleetX.

    x -> matmul_v2(w1) -> elementwise_add(b1) -> gelu -> matmul_v2(w2)
    -> elementwise_add(b2), where w1/w2 are 2-D params (nodes 1 and 9) and
    b1/b2 are 1-D params (nodes 4 and 12).
    """

    name = "ffn"

    def __init__(self):
        super().__init__()

    def build(self):
        def add_op(op_id, op_type, inputs, out_id):
            # op node, then its (source node id, input slot) edges, then the
            # var it writes through its "Out" slot
            self.add_node(op_id, type=op_type)
            for src_id, slot in inputs:
                self.add_edge(src_id, op_id, input_name=slot)
            self.add_node(out_id, type="var")
            self.add_edge(op_id, out_id, output_name="Out")

        # block input
        self.add_node(0, type="var")

        # first linear layer: x @ w1 + b1
        self.add_node(1, dim=2, type="param")
        add_op(2, "matmul_v2", [(0, "X"), (1, "Y")], 3)
        self.add_node(4, dim=1, type="param")
        add_op(5, "elementwise_add", [(3, "X"), (4, "Y")], 6)

        # activation
        add_op(7, "gelu", [(6, "X")], 8)

        # second linear layer: h @ w2 + b2
        self.add_node(9, dim=2, type="param")
        add_op(10, "matmul_v2", [(8, "X"), (9, "Y")], 11)
        self.add_node(12, dim=1, type="param")
        add_op(13, "elementwise_add", [(11, "X"), (12, "Y")], 14)

        # sharding of the input (node 0) and the two weights (nodes 1, 9)
        self.attrs["shard_spec"] = {
            "dp_mp": {0: [0, -1, -1], 1: [-1, 1], 9: [1, -1]},
            "mp_dp": {0: [1, -1, -1], 1: [-1, 0], 9: [0, -1]},
            "mp": {1: [-1, 0], 9: [0, -1]},
            "dp": {0: [0, -1, -1], 1: [-1, -1], 9: [-1, -1]},
        }
        # number of tensors this pattern shards
        self.attrs["sharded_tensors"] = 2


@register_pattern
class SharedWordEmbeddingPattern(BasePattern):
    """Shared word embedding pattern defined by GPT model in PaddleFleetX.

    The same 2-D embedding table (node 1) is the W input of a
    lookup_table_v2 (node 2) and the Y operand of a matmul_v2 (node 5).
    """

    name = "shared_word_embedding"

    def __init__(self):
        super().__init__()

    def build(self):
        # lookup_table_v2: token ids (Ids) + table (W) -> embeddings
        self.add_node(0, type="data")
        self.add_node(1, dim=2, type="param")
        self.add_node(2, type="lookup_table_v2")
        self.add_edge(0, 2, input_name="Ids")
        self.add_edge(1, 2, input_name="W")
        self.add_node(3, type="var")
        self.add_edge(2, 3, output_name="Out")

        # matmul_v2 that reuses the same table (node 1) as its Y operand
        self.add_node(4, type="var")
        self.add_node(5, type="matmul_v2")
        self.add_edge(4, 5, input_name="X")
        self.add_edge(1, 5, input_name="Y")
        self.add_node(6, type="var")
        self.add_edge(5, 6, output_name="Out")

        # sharding of the ids (0), the table (1) and the matmul input (4)
        self.attrs["shard_spec"] = {
            "dp_mp": {0: [0, -1], 1: [1, -1], 4: [0, -1, -1]},
            "mp_dp": {0: [1, -1], 1: [0, -1], 4: [1, -1, -1]},
            "mp": {0: [-1, -1], 1: [0, -1], 4: [-1, -1, -1]},
            "dp": {0: [0, -1], 1: [-1, -1], 4: [0, -1, -1]},
        }
        # number of tensors this pattern shards
        self.attrs["sharded_tensors"] = 3


@register_pattern
class PositionEmbeddingPattern(BasePattern):
    """Position embedding pattern defined by GPT model in PaddleFleetX.

    A lookup_table_v2 (node 2) reads position ids (data node 0) and a 2-D
    embedding table (node 1) and writes var node 3.
    """

    name = "position_embedding"

    def __init__(self):
        super().__init__()

    def build(self):
        # lookup_table_v2: ids (Ids) + table (W) -> embeddings
        self.add_node(0, type="data")
        self.add_node(1, dim=2, type="param")
        self.add_node(2, type="lookup_table_v2")
        self.add_edge(0, 2, input_name="Ids")
        self.add_edge(1, 2, input_name="W")
        self.add_node(3, type="var")
        self.add_edge(2, 3, output_name="Out")

        # NOTE(review): under "dp_mp" node 3 maps to [-1, -1, -1] although
        # node 0 maps to [0, -1], while "mp_dp" and "dp" give node 3 the
        # same leading entry as node 0 — confirm this asymmetry is intended.
        self.attrs["shard_spec"] = {
            "dp_mp": {0: [0, -1], 1: [-1, -1], 3: [-1, -1, -1]},
            "mp_dp": {0: [1, -1], 1: [-1, -1], 3: [1, -1, -1]},
            "mp": {0: [-1, -1], 1: [-1, -1], 3: [-1, -1, -1]},
            "dp": {0: [0, -1], 1: [-1, -1], 3: [0, -1, -1]},
        }
        # number of tensors this pattern shards
        self.attrs["sharded_tensors"] = 1


@register_pattern
class UnsqueezeDataPattern(BasePattern):
    """Unsqueeze data pattern defined by GPT model in the PaddleFleetX.

    A data var (node 0) feeding the X slot of an unsqueeze2 op (node 1).
    """

    name = "unsqueeze_data"

    def __init__(self):
        super().__init__()

    def build(self):
        # data var -> unsqueeze2 (slot X)
        self.add_node(0, type="data")
        self.add_node(1, type="unsqueeze2")
        self.add_edge(0, 1, input_name="X")

        # sharding of the data var for pure mp, pure dp and hybrid dp+mp
        self.attrs["shard_spec"] = {
            "dp_mp": {0: [0, -1]},
            "mp_dp": {0: [1, -1]},
            "mp": {0: [-1, -1]},
            "dp": {0: [0, -1]},
        }
        # number of tensors this pattern shards
        self.attrs["sharded_tensors"] = 1


@register_pattern
class ReshapeDataPattern(BasePattern):
    """Reshape data pattern defined by GPT model in PaddleFleetX.

    A data var (node 0) feeding the X slot of a reshape2 op (node 1).
    """

    name = "reshape_data"

    def __init__(self):
        super().__init__()

    def build(self):
        # data var -> reshape2 (slot X)
        self.add_node(0, type="data")
        self.add_node(1, type="reshape2")
        self.add_edge(0, 1, input_name="X")

        # sharding of the data var per parallelism
        self.attrs["shard_spec"] = {
            "dp_mp": {0: [0, -1]},
            "mp_dp": {0: [1, -1]},
            "mp": {0: [-1, -1]},
            "dp": {0: [0, -1]},
        }
        # number of tensors this pattern shards
        self.attrs["sharded_tensors"] = 1


class GraphUtil:
    """Utilities to convert a block's ops into a Graph and to match patterns."""

    @staticmethod
    def convert_to_graph(block):
        """Convert the ops of ``block`` into a Graph.

        Each op becomes a node whose attrs are ``op.all_attrs()`` plus
        ``"type": op.type``. Each var read or written by an op becomes a
        single node (created the first time it is seen) whose ``type`` attr is
        ``"param"``, ``"data"`` or ``"var"``. Params and data vars first seen
        as inputs also get ``"dim"`` (``len(var.shape)``). Input edges run
        var -> op and carry ``attrs["input_name"]``. Output edges run
        op -> var and carry ``attrs["output_name"]``.

        The returned graph also carries:
            attrs["var_to_id"]:         {var_name: node_id}
            attrs["id_to_var_desc_id"]: {node_id: var.desc.original_id()}
            attrs["id_to_var_name"]:    {node_id: var_name}
            attrs["op_to_id"]:          {op.desc.id(): node_id}
            attrs["id_to_op"]:          {node_id: op}
            _attr_to_nodes:             {op_node_id: {slot_name: [var_node]}}

        Args:
            block: program block whose ``ops`` are converted.

        Returns:
            Graph: the converted graph.
        """
        graph = Graph()
        graph.attrs["var_to_id"] = {}  # {var_name: node_id}
        graph.attrs["id_to_var_desc_id"] = {}  # {node_id: var_desc_id}
        graph.attrs["id_to_var_name"] = {}
        graph.attrs["op_to_id"] = {}  # {op_id: node_id}
        graph.attrs["id_to_op"] = {}  # {node_id: op}

        node_id = -1
        for op in block.ops:
            attrs = op.all_attrs()
            attrs["type"] = op.type
            node_id += 1

            # create op node
            op_node = graph.add_node(node_id, **attrs)
            graph.attrs["op_to_id"][op.desc.id()] = op_node.id
            graph.attrs["id_to_op"][op_node.id] = op
            graph._attr_to_nodes[op_node.id] = {}

            for input_name in op.input_names:
                graph._attr_to_nodes[op_node.id][input_name] = []
                for var_name in op.input(input_name):
                    if var_name not in graph.attrs["var_to_id"]:
                        # create var node
                        node_id += 1
                        var_node = graph.add_node(node_id)
                        var = block._var_recursive(var_name)
                        if var.is_parameter:
                            var_node.attrs["type"] = "param"
                            var_node.attrs["dim"] = len(var.shape)
                        elif var.is_data:
                            var_node.attrs["type"] = "data"
                            var_node.attrs["dim"] = len(var.shape)
                        else:
                            var_node.attrs["type"] = "var"
                        graph.attrs["var_to_id"][var_name] = var_node.id
                        graph.attrs["id_to_var_desc_id"][
                            var_node.id
                        ] = var.desc.original_id()
                        graph.attrs["id_to_var_name"][var_node.id] = var_name
                    else:
                        var_node_id = graph.attrs["var_to_id"][var_name]
                        var_node = graph._nodes[var_node_id]

                    # create edge that input -> op
                    input_edge = graph.add_edge(var_node.id, op_node.id)
                    input_edge.attrs["input_name"] = input_name
                    graph._attr_to_nodes[op_node.id][input_name].append(
                        var_node
                    )

            # Outputs are handled at op level, exactly once per op, so ops
            # without input slots still get their output nodes and edges.
            for output_name in op.output_names:
                graph._attr_to_nodes[op_node.id][output_name] = []
                for var_name in op.output(output_name):
                    if var_name not in graph.attrs["var_to_id"]:
                        # create var node
                        node_id += 1
                        var_node = graph.add_node(node_id)
                        var = block._var_recursive(var_name)
                        if var.is_parameter:
                            var_node.attrs["type"] = "param"
                        else:
                            var_node.attrs["type"] = "var"
                        graph.attrs["var_to_id"][var_name] = var_node.id
                        graph.attrs["id_to_var_desc_id"][
                            var_node.id
                        ] = var.desc.original_id()
                        graph.attrs["id_to_var_name"][var_node.id] = var_name
                    else:
                        var_node_id = graph.attrs["var_to_id"][var_name]
                        var_node = graph._nodes[var_node_id]

                    # create edge that op -> output
                    output_edge = graph.add_edge(op_node.id, var_node.id)
                    output_edge.attrs["output_name"] = output_name
                    graph._attr_to_nodes[op_node.id][output_name].append(
                        var_node
                    )

        return graph

    @staticmethod
    def match_pattern(pattern, graph):
        """Find the sub-graphs of ``graph`` that match ``pattern``.

        Matching starts from the first op node of ``pattern``. For every
        graph node of the same type, the pattern is walked recursively:
        op nodes must have equal ``type``; a var node matches when every attr
        of the pattern node is present with an equal value on the graph node.
        Neighbours are paired through the input/output slot names on the
        edges. Only the first var of a graph op's slot is considered.

        A complete match is kept only if none of its graph node ids is an op
        node already used by an earlier kept match.

        Args:
            pattern: pattern graph (e.g. a BasePattern subclass instance).
            graph: graph built by ``convert_to_graph``.

        Returns:
            tuple: ``(results, matched_ids)``. ``results`` is a list of
            ``{pattern_node_id: graph_node_id}`` dicts and ``matched_ids``
            is the set of graph node ids used by the kept matches.
        """

        def _is_op_node(node):
            """Judge whether node is op node."""
            return node.attrs["type"] not in ["var", "param", "data"]

        def _compare_op_node(src, tgt):
            """Compare whether two op nodes are equivalent."""
            return src.attrs["type"] == tgt.attrs["type"]

        def _compare_var_node(src, tgt):
            """Compare whether two var nodes are equivalent."""
            for key in src.attrs:
                if key not in tgt.attrs:
                    return False
                if src.attrs[key] != tgt.attrs[key]:
                    return False
            return True

        def _match_core(src_node, tgt_node):
            """Map pattern ``src_node`` onto graph ``tgt_node`` recursively.

            Extends ``result`` and sets ``not_matched`` on the first mismatch.
            """
            nonlocal not_matched

            # an earlier branch already failed; stop exploring
            if not_matched:
                return

            if _is_op_node(src_node):
                # compare op node whether equal
                if not _compare_op_node(src_node, tgt_node):
                    not_matched = True
                    return

                result[src_node.id] = tgt_node.id

                # input var nodes
                for node in src_reverse_adjs[src_node.id]:
                    # has visited
                    if node.id in result:
                        continue
                    input_name = src_edges[node.id][src_node.id].attrs[
                        "input_name"
                    ]

                    # NOTE: do not support one input name or output name
                    # corresponding to multiple vars
                    compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                        input_name, None
                    )
                    if not compare_nodes:
                        not_matched = True
                        return
                    _match_core(node, compare_nodes[0])

                # output var nodes
                for out_id in src_edges[src_node.id].keys():
                    # has visited
                    if out_id in result:
                        continue
                    node = src_nodes[out_id]
                    output_name = src_edges[src_node.id][out_id].attrs[
                        "output_name"
                    ]

                    # NOTE: do not support one input name or output name
                    # corresponding to multiple vars
                    compare_nodes = tgt_attr_to_nodes[tgt_node.id].get(
                        output_name, None
                    )
                    if not compare_nodes:
                        not_matched = True
                        return
                    _match_core(node, compare_nodes[0])

            else:
                # compare var nodes whether equal
                if not _compare_var_node(src_node, tgt_node):
                    not_matched = True
                    return

                result[src_node.id] = tgt_node.id

                # as input for op node
                for op_id in src_edges[src_node.id].keys():
                    if op_id in result:
                        continue

                    input_name = src_edges[src_node.id][op_id].attrs[
                        "input_name"
                    ]

                    # first consumer in the graph that reads this var through
                    # the same slot and is not mapped yet
                    compare_node = None
                    for cand_id in tgt_edges[tgt_node.id].keys():
                        edge = tgt_edges[tgt_node.id][cand_id]
                        if (
                            edge.attrs["input_name"] == input_name
                            and cand_id not in result.values()
                        ):
                            compare_node = tgt_nodes[cand_id]
                            break

                    if compare_node is None:
                        not_matched = True
                        return
                    _match_core(src_nodes[op_id], compare_node)

                # as output for op node
                for node in src_reverse_adjs[src_node.id]:
                    if node.id in result:
                        continue

                    output_name = src_edges[node.id][src_node.id].attrs[
                        "output_name"
                    ]

                    # first producer in the graph writing this var through
                    # the same slot
                    compare_node = None
                    for item in tgt_reverse_adjs[tgt_node.id]:
                        edge = tgt_edges[item.id][tgt_node.id]
                        if edge.attrs["output_name"] == output_name:
                            compare_node = tgt_nodes[item.id]
                            break
                    if compare_node is None:
                        not_matched = True
                        return
                    _match_core(src_nodes[node.id], compare_node)

        results = []
        matched_ids = set()
        matched_op_node_ids = set()
        result = {}
        not_matched = False
        src_nodes = pattern.nodes
        src_edges = pattern._adjs
        src_reverse_adjs = pattern._reverse_adjs

        tgt_nodes = graph.nodes
        tgt_edges = graph._adjs
        tgt_reverse_adjs = graph._reverse_adjs
        tgt_attr_to_nodes = graph._attr_to_nodes
        id_to_op = graph.attrs["id_to_op"]

        # starts with a op node
        src_start_node = None
        for node_id in src_nodes:
            node = src_nodes[node_id]
            if _is_op_node(node):
                src_start_node = node
                break
        assert src_start_node is not None

        for node_id in tgt_nodes:
            node = tgt_nodes[node_id]
            if node.attrs["type"] != src_start_node.attrs["type"]:
                continue

            not_matched = False
            _match_core(src_start_node, node)
            # keep a full match only if it reuses no already-claimed op node
            if not not_matched and not any(
                value in matched_op_node_ids for value in result.values()
            ):
                results.append(result)
                for value in result.values():
                    matched_ids.add(value)
                    if value in id_to_op:
                        matched_op_node_ids.add(value)
            # the next attempt starts from an empty mapping
            result = {}

        return results, matched_ids

    @staticmethod
    def match_all_patterns(graph):
        """Match every registered pattern against ``graph``.

        Patterns are tried in ``_PATTERNS`` order. A match is kept only if
        none of its graph node ids belongs to a match kept earlier for any
        pattern, so each graph node ends up in at most one kept match.

        Args:
            graph: graph built by ``convert_to_graph``.

        Returns:
            dict: pattern name -> list of ``{pattern_node_id: graph_node_id}``
            dicts, i.e. ``{"pattern_name": [{pattern_node_id: graph_node_id}, ]}``.
        """
        matched_results = {}
        matched_ids = set()
        for pattern in _PATTERNS.values():
            results, _ = GraphUtil.match_pattern(pattern, graph)
            for result in results:
                graph_node_ids = list(result.values())
                if any(node_id in matched_ids for node_id in graph_node_ids):
                    continue
                # claim every graph node of this match so that later
                # patterns cannot reuse any of them
                matched_ids.update(graph_node_ids)
                matched_results.setdefault(pattern.name, []).append(result)

        return matched_results
757 758 759


class OperatorClusteringUtil:
    """Operator clustering util is used to cluster operators to layers.

    The helpers implement a suffix-array based search for repeated runs of
    operator types; such runs are treated as candidate layers.
    """

    # Op types a repeated sub sequence is preferred to start with.
    common_starts = ["layer_norm", "matmul_v2", "matmul"]

    @staticmethod
    def get_ranks(seq):
        """Get the rank array of ``seq`` by the prefix-doubling algorithm.

        ``ranks[i]`` is the position of the suffix ``seq[i:]`` among all
        suffixes of ``seq`` sorted in ascending order, e.g. ``list("banana")``
        gives ``[3, 2, 5, 1, 4, 0]``. Items must be hashable and mutually
        comparable.

        Args:
            seq (list): The sequence to rank.

        Returns:
            list: The rank of each suffix, indexed by suffix start.
        """
        item_to_rank = {
            item: idx for idx, item in enumerate(sorted(set(seq)))
        }
        ranks = [item_to_rank[item] for item in seq]

        length = len(ranks)
        interval = 1
        while interval < length:
            # Rank suffixes by their first 2 * interval items: pair the current
            # rank with the rank `interval` positions ahead. -1 marks a suffix
            # that ends earlier, so it sorts before the longer ones.
            pairs = [
                (rank, ranks[idx + interval] if idx + interval < length else -1)
                for idx, rank in enumerate(ranks)
            ]
            # Deduplicate with a set so each round stays O(n log n).
            pair_to_rank = {
                pair: idx for idx, pair in enumerate(sorted(set(pairs)))
            }
            ranks = [pair_to_rank[pair] for pair in pairs]
            interval *= 2

        return ranks

    @staticmethod
    def get_suffixes(ranks):
        """Get the suffix array from a rank array.

        ``suffixes[r]`` is the start index of the suffix whose rank is ``r``,
        i.e. the inverse permutation of ``ranks``.
        """
        suffixes = [0] * len(ranks)
        for idx, rank in enumerate(ranks):
            suffixes[rank] = idx
        return suffixes

    @staticmethod
    def get_heights(suffixes, seq):
        """Get the height (LCP) array from the suffix array and ``seq``.

        ``heights[i]`` (``i >= 1``) is the length of the longest common prefix
        of the suffixes ranked ``i - 1`` and ``i``; ``heights[0]`` is -1.
        """
        heights = [-1] * len(suffixes)
        for i in range(1, len(seq)):
            prev_suffix = seq[suffixes[i - 1] :]
            cur_suffix = seq[suffixes[i] :]
            common = 0
            for prev_item, cur_item in zip(prev_suffix, cur_suffix):
                if prev_item != cur_item:
                    break
                common += 1
            heights[i] = common

        return heights

    @staticmethod
    def get_longest_repeated_sub_seq(suffixes, heights, seq):
        """Get the longest repeated sub sequence by the suffix array algorithm.

        Lengths ``k`` are tried from ``len(seq) // 2`` down to 2. For each
        ``k``, adjacent sorted suffixes sharing a prefix of at least ``k``
        items are grouped. A group whose earliest and latest start indexes are
        at least ``k`` apart yields a repeat whose first and last occurrences
        do not overlap. A repeat starting with an op type in ``common_starts``
        is returned immediately; otherwise the last repeat found for the
        largest such ``k`` is returned.

        Returns:
            list or None: The repeated sub sequence, or None if there is no
            repeat of length >= 2.
        """
        length = len(seq)
        if length <= 1:
            return None

        longest_sub_seq = None
        k = length // 2
        while k >= 2:
            # Each group starts at a sorted position whose suffix differs from
            # its predecessor within k items; every later member shares at
            # least k items with the group's first member.
            height_groups = []
            height_group = []
            for i in range(1, len(heights)):
                if heights[i] >= k:
                    if i == 1:
                        height_group.append(0)
                    height_group.append(i)
                elif i == 1:
                    height_groups.append([0])
                    height_group = [i]
                else:
                    height_groups.append(height_group)
                    height_group = [i]
            if height_group:
                height_groups.append(height_group)

            for height_group in height_groups:
                index_group = [suffixes[idx] for idx in height_group]
                min_index = min(index_group)
                if max(index_group) - min_index >= k:
                    longest_sub_seq = seq[min_index : min_index + k]
                    if (
                        longest_sub_seq[0]
                        in OperatorClusteringUtil.common_starts
                    ):
                        return longest_sub_seq
            if longest_sub_seq is not None:
                return longest_sub_seq

            k -= 1

        return longest_sub_seq

    @staticmethod
    def get_decomposed_sub_seq(seq):
        """Get the decomposed sub seq s of seq S such that s * R = S.

        Returns the shortest prefix ``s`` of length >= 2 such that ``seq`` is
        ``s`` repeated ``R`` times (falling back to a copy of the whole
        ``seq``). A falsy or single-item ``seq`` is returned unchanged.
        """
        if not seq:
            return seq

        seq_len = len(seq)
        if seq_len == 1:
            return seq

        for interval in range(2, seq_len + 1):
            if seq_len % interval:
                continue
            candidate = seq[:interval]
            if all(
                seq[start : start + interval] == candidate
                for start in range(0, seq_len, interval)
            ):
                return candidate

        return seq

    @staticmethod
    def replace_by_decomposed_seq(sub_seq, seq):
        """Group the occurrences of ``sub_seq`` in ``seq`` into nested lists.

        Scans ``seq`` left to right; each non-overlapping occurrence of
        ``sub_seq`` becomes one list item, other items are kept as-is. A
        falsy ``sub_seq`` returns ``seq`` unchanged.
        """
        if not sub_seq:
            return seq

        result = []
        sub_seq_len = len(sub_seq)
        i = 0
        while i < len(seq):
            if seq[i : i + sub_seq_len] == sub_seq:
                result.append(seq[i : i + sub_seq_len])
                i += sub_seq_len
            else:
                result.append(seq[i])
                i += 1

        return result

    @staticmethod
    def stop_replace(seq):
        """Return True when every item of ``seq`` is already a group (list)."""
        return all(isinstance(item, list) for item in seq)


935
class ClusterPartitionUtil:
    """Cluster partition util is used to get device meshes and process meshes."""

    @staticmethod
    def factorization(num):
        """Return every factor pair ``[i, num // i]`` of ``num`` with ``i <= sqrt(num)``.

        Pairs are ordered by increasing ``i``; e.g. ``12`` gives
        ``[[1, 12], [2, 6], [3, 4]]`` and ``0`` gives ``[]``.
        """
        factors = []
        for i in range(1, int(math.floor(math.sqrt(num))) + 1):
            if num % i == 0:
                # Integer division keeps the cofactor exact for any int size.
                factors.append([i, num // i])
        return factors

    @staticmethod
    def complete_meshes(partitions: list, num: int):
        """Adjust the factor pairs used to split ``num`` nodes.

        Special cases: 2 nodes give ``[[1, 2], [2, 1]]`` and 3 nodes give
        ``[[1, 2], [2, 1], [1]]``. If ``partitions`` holds a single pair
        (e.g. ``num`` is prime), the factor pairs of ``num - 1`` are used
        instead, followed by the marker ``[1]``; ``partition_cluster`` adds
        that leftover node to the last sub-mesh. Otherwise ``partitions`` is
        returned unchanged.
        """
        # special cases
        if num == 2:
            return [[1, 2], [2, 1]]
        if num == 3:
            return [[1, 2], [2, 1], [1]]
        if len(partitions) == 1:
            partitions = ClusterPartitionUtil.factorization(num - 1)
            partitions.append([1])
        return partitions

    @staticmethod
    def partition_cluster(
        n: int,
        m: int,
        filter=(complete_meshes.__func__,),
    ) -> list:
        """
        Partition the cluster into possible device meshes.

        Args:
            n (int): The number of nodes.
            m (int): The number of devices on each node.
            filter (sequence): Functions ``f(partitions, n) -> partitions``
                applied in order to the factor pairs of ``n`` to get the
                useful partitions. Default: ``(complete_meshes,)``.
        Returns:
            device_meshes (list): The possible device meshes. Each device mesh
            is a list of ``[nodes, devices_per_node]`` sub-meshes.
        """
        partition_result = ClusterPartitionUtil.factorization(n)
        for func in filter:
            partition_result = func(partition_result, n)

        device_meshes = []
        if n == 1:
            # A single node is split along its devices.
            for rows, cols in ClusterPartitionUtil.factorization(m):
                device_meshes.append([[1, cols] for _ in range(rows)])
        else:
            # A trailing [1] marker means one node is left over; it is added
            # to the last sub-mesh of every device mesh.
            increment = (
                1 if partition_result and partition_result[-1] == [1] else 0
            )
            for partition in partition_result:
                if len(partition) < 2:
                    continue
                device_mesh = [[partition[1], m] for _ in range(partition[0])]
                device_mesh[-1][0] += increment
                device_meshes.append(device_mesh)

        return device_meshes


def convert_to_process_meshes(device_mesh: list) -> list:
    """
    Transfer one device mesh into its possible process meshes.

    Args:
        device_mesh (list): One device mesh ``[n, m]``: ``n`` nodes with
            ``m`` devices each.
    Returns:
        process_meshes (list): Possible process mesh shapes; a 1-D shape is
        ``[size]`` and a 2-D shape is ``[rows, cols]``.
    """
    n, m = device_mesh[0], device_mesh[1]
    process_meshes = []
    if n == 1:
        # A single node: every factor pair of its m devices is a shape.
        for rows, cols in ClusterPartitionUtil.factorization(m):
            process_meshes.append([cols] if rows == 1 else [rows, cols])
    else:
        # Several nodes: fold a factor of the n nodes into the device axis.
        for mul1, mul2 in ClusterPartitionUtil.factorization(n):
            if mul1 == 1:
                process_meshes.append([m * mul2])
            elif mul1 != mul2:
                # Integer division: mul1 * mul2 == n exactly.
                process_meshes.append([n // mul2, m * mul2])
            process_meshes.append([n // mul1, m * mul1])
    return process_meshes


C
caozhou 已提交
1032
class RuleBasedTuner:
1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045
    """
    A tuner based on rule from expert experience to search a good parallel strategy.
    Args:
        dist_context (DistributedContext): The distributed context.
        mode (str): The mode of current task, it can be train or eval. Default: train.
        level (str): The level of this tuner, it can be o1 or o2.
                     o2 level may find better strategy but need more time than o1.
                     If level is o1, it means all layers within same parallelism and place layers evenly when in pipeline parallism.
                     If level is o2, it means layers can has own parallelism and place layers may not evenly.
                     Default: o1.
    """

    def __init__(self, dist_context, mode="train", level="o1"):
        """Create the tuner.

        Args:
            dist_context (DistributedContext): The distributed context; its
                ``cluster`` and ``_json_config`` are read here.
            mode (str): The mode of current task, train or eval. Default: train.
            level (str): The level of this tuner, "o1" or "o2". Default: "o1".
        """
        self._dist_context = dist_context
        self._cluster = self._dist_context.cluster
        self._mode = mode
        assert level in [
            "o1",
            "o2",
        ], "level must be 'o1' or 'o2', but got {}.".format(level)
        self._level = level
        self._logger = get_logger(logging.INFO)
        self._use_dp = False

        # forward sub program
        self.fwd_sub_programs = OrderedDict()

        # dist_context of sub program
        self.sub_programs_dist_context = OrderedDict()

        # graph of forward sub program
        self.fwd_sub_program_graphs = OrderedDict()

        # full main program
        self.full_main_program = None

        # full startup program
        self.full_startup_program = None

        # full main program dist context
        self.full_main_program_dist_context = None

        # tensor dist attribute from pattern setting
        self.tensor_dist_attrs = {}

        # op original id to op mapping
        self.op_original_id_to_op = {}

        # op original id to op idx in program
        self.op_original_id_to_idx = {}

        # op original id to grad op original id mapping
        self.op_original_id_to_grad_op_original_id = {}

        # all process meshes that the cluster can express
        self.process_meshes = []

        # all device meshes that the cluster can be partitioned
        self.device_meshes_list = []

        # the best cost of stage in a given device mesh
        self.stage_best_cost_of_dm = {}

        # the best cost of stage in a given process mesh
        self.stage_best_cost_of_pm = {}

        # the op clustering result
        self.layers = []

        # False only when PADDLE_AUTO_PARALLEL_STAGE is set to "tuner".
        self._is_run = os.getenv("PADDLE_AUTO_PARALLEL_STAGE") != "tuner"

        # Optional save path taken from the json config's "tuner_save_path".
        self._strategy_path = None
        json_config = self._dist_context._json_config
        if json_config:
            try:
                self._strategy_path = json_config["tuner_save_path"]
            except (KeyError, TypeError):
                # The config has no such entry (or is not a mapping).
                self._strategy_path = None

    @property
    def dist_context(self):
        return self._dist_context

    @property
    def cluster(self):
        return self._cluster

    @property
    def mode(self):
        return self._mode

    @property
    def level(self):
        return self._level

    def convert_process_mesh_to_key(self, process_mesh):
        """Serialize a process mesh as ``"<process ids>;<shape>"``.

        Both parts are comma-separated; e.g. ids ``[0, 1, 2, 3]`` with shape
        ``[2, 2]`` give ``"0,1,2,3;2,2"``.
        """
        process_part = ",".join(map(str, process_mesh._process_ids))
        shape_part = ",".join(map(str, process_mesh._shape))
        return "{};{}".format(process_part, shape_part)

    def gen_full_program(self):
        """Generate full program that contain backward and update phase program if mode is train."""
        self.full_main_program = self.dist_context.serial_main_program.clone()
        if self.mode == "train":
            self.full_startup_program = (
                self.dist_context.serial_startup_program.clone()
            )
            loss = self.full_main_program.global_block().vars[
                self.dist_context.serial_loss.name
            ]
            serial_optimizer = self._dist_context.serial_optimizer
            optimizer = copy.deepcopy(serial_optimizer)
            self.full_main_program_dist_context = DistributedContext(
                serial_main_prog=self.full_main_program,
                serial_startup_prog=self.full_startup_program,
                serial_loss=loss,
            )
            # if in train mode, generate backward and update program.
            with program_guard(
                self.full_main_program, self.full_startup_program
            ):
                params_grads = append_backward(
                    loss,
                    distop_context=self.full_main_program_dist_context.dist_op_context,
                )

            with program_guard(
                self.full_main_program, self.full_startup_program
            ):
                with unique_name.guard("opt_"):
                    optimizer_ops = optimizer.apply_gradients(params_grads)

            # op original id to grad op id
            for idx, op in enumerate(self.full_main_program.global_block().ops):
                self.op_original_id_to_op[op.desc.original_id()] = op
                self.op_original_id_to_idx[op.desc.original_id()] = idx

            grad_op_id_to_op_id = (
                self.full_main_program_dist_context.dist_op_context.grad_op_id_to_op_id
            )

            for grad_op_original_id in grad_op_id_to_op_id:
                op_id = grad_op_id_to_op_id[grad_op_original_id]
                self.op_original_id_to_grad_op_original_id[
                    op_id
                ] = grad_op_original_id

    def cluster_operators(self):
        """Group operators of the serial main program into layers.

        Dist attrs already set on the program's ops and vars are reset first.
        The sequence of op types is then rewritten until every item is a group
        (a list). Each round takes the first contiguous run of ungrouped op
        types and groups the occurrences of its longest repeated sub sequence
        (decomposed to its smallest repeating unit). A run with no repeat is
        merged into the next group, else the previous group; if there is no
        group at all, the whole sequence becomes a single layer.

        Returns:
            list: The layers in program order; each layer is a list of ops.
        """
        main_program = self._dist_context._serial_main_program
        ops = main_program.global_block().ops

        # clear op dist attr when user shard tensor or op but in the full auto parallel mode.
        for op in ops:
            op.dist_attr = OperatorDistAttr(op.desc)

        program_vars = main_program.global_block().vars
        for var_name in program_vars:
            program_vars[var_name].dist_attr = TensorDistAttr(
                program_vars[var_name].desc
            )

        seq = [op.type for op in ops]
        while not OperatorClusteringUtil.stop_replace(seq):
            # Indexes of the first contiguous run of ungrouped op types.
            to_replace_idxes = []
            for idx, item in enumerate(seq):
                if not isinstance(item, list):
                    to_replace_idxes.append(idx)
                elif to_replace_idxes:
                    break
            to_replace_seq = [seq[idx] for idx in to_replace_idxes]
            first_idx, last_idx = to_replace_idxes[0], to_replace_idxes[-1]

            ranks = OperatorClusteringUtil.get_ranks(to_replace_seq)
            suffixes = OperatorClusteringUtil.get_suffixes(ranks)
            heights = OperatorClusteringUtil.get_heights(
                suffixes, to_replace_seq
            )
            longest_sub_seq = (
                OperatorClusteringUtil.get_longest_repeated_sub_seq(
                    suffixes, heights, to_replace_seq
                )
            )

            if longest_sub_seq is None:
                if not self._merge_into_neighbor_group(
                    seq, to_replace_seq, first_idx, last_idx
                ):
                    # No group exists yet: the whole sequence is one layer.
                    seq = [to_replace_seq]
                    break
                # The run now lives in a neighboring group; drop it from seq.
                replacement = []
            else:
                decomposed_sub_seq = (
                    OperatorClusteringUtil.get_decomposed_sub_seq(
                        longest_sub_seq
                    )
                )
                replacement = OperatorClusteringUtil.replace_by_decomposed_seq(
                    decomposed_sub_seq, to_replace_seq
                )

            seq = seq[:first_idx] + replacement + seq[last_idx + 1 :]

        # Groups hold op types in program order, so consecutive slices of
        # ops line up with them.
        layers = []
        start = 0
        for group in seq:
            end = start + len(group)
            layers.append(ops[start:end])
            start = end

        return layers

    @staticmethod
    def _merge_into_neighbor_group(seq, run, first_idx, last_idx):
        """Merge ``run`` (the items ``seq[first_idx:last_idx + 1]``) into an adjacent group.

        Prefers prepending to the first group after the run, then appending to
        the last group before it. Updates ``seq`` in place and returns True;
        returns False when ``seq`` contains no group.
        """
        for i in range(last_idx + 1, len(seq)):
            if isinstance(seq[i], list):
                seq[i] = run + seq[i]
                return True
        for i in range(first_idx - 1, -1, -1):
            if isinstance(seq[i], list):
                seq[i].extend(run)
                return True
        return False
1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626

    def match_program(self, program):
        """Match the registered patterns against ``program`` and record shard specs.

        For each pattern match, the dims mapping that the pattern's
        ``shard_spec`` gives a pattern node under a parallelism is stored as
        ``self.tensor_dist_attrs[var_desc_id][parallelism]`` for the matched
        var, and logged. If nothing matches, ``self._use_dp`` is set to True.
        """
        graph = GraphUtil.convert_to_graph(program.global_block())
        matched_results = GraphUtil.match_all_patterns(graph)
        if not matched_results:
            self._logger.info(
                "No pattern has be matched by this program. Currently, only the transformer-based models are supported. Data parallelism will be used."
            )
            self._use_dp = True
            return

        for pattern_name, matches in matched_results.items():
            shard_specs = _PATTERNS[pattern_name].attrs["shard_spec"]
            for parallelism, shard_spec in shard_specs.items():
                for pattern_node_id, dims_mapping in shard_spec.items():
                    for match in matches:
                        var_id = match[pattern_node_id]
                        var_desc_id = graph.attrs["id_to_var_desc_id"][var_id]
                        per_parallelism = self.tensor_dist_attrs.setdefault(
                            var_desc_id, {}
                        )
                        per_parallelism[parallelism] = dims_mapping
                        tensor_name = graph.attrs["id_to_var_name"][var_id]
                        self._logger.info(
                            "{}'s shard_spec may be {} when under {} parallelism.".format(
                                tensor_name, dims_mapping, parallelism
                            )
                        )

    def gen_fwd_sub_programs_by_clone(self):
        """Generate all forward sub programs by cloned from the original program."""
        for idx, layer in enumerate(self.layers):
            sub_fwd_program = self._gen_fwd_sub_program_by_clone(layer)
            self.fwd_sub_programs[idx] = sub_fwd_program

    def _gen_fwd_sub_program_by_clone(self, ops):
        """Generate a new static Program holding copies of ``ops``.

        Each op desc is copied into the new program's global block, and every
        var an op reads or writes is cloned once:
        - parameters are re-created as ``Parameter`` with the source's
          training attributes;
        - other vars are cloned and keep the source's ``persistable`` flag.
        Every cloned var keeps the ``original_id`` of its source var.

        Args:
            ops (list): The ops of one layer; their vars are looked up in the
                block of ``ops[0]``.

        Returns:
            Program: The forward sub program.
        """
        program = paddle.static.Program()
        src_vars = ops[0].block.vars
        target_block = program.global_block()

        def _clone_non_param(var_name):
            # Clone a non-parameter var and keep its persistable flag.
            target_block._clone_variable(src_vars[var_name])
            target_block.vars[var_name].persistable = src_vars[
                var_name
            ].persistable

        def _clone_param(src_var):
            # Re-create a parameter in the target block with its attributes.
            Parameter(
                block=target_block,
                type=src_var.type,
                name=src_var.name,
                shape=src_var.shape,
                dtype=src_var.dtype,
                lod_level=src_var.lod_level,
                error_clip=src_var.error_clip,
                stop_gradient=src_var.stop_gradient,
                is_data=src_var.is_data,
                belong_to_optimizer=src_var.belong_to_optimizer,
                trainable=src_var.trainable,
                optimize_attr=src_var.optimize_attr,
                regularizer=src_var.regularizer,
                do_model_average=src_var.do_model_average,
                need_clip=src_var.need_clip,
            )

        def _keep_original_id(var_name):
            # Keep the source var's original id on the cloned var.
            target_block.vars[var_name].desc.set_original_id(
                src_vars[var_name].desc.original_id()
            )

        with paddle.static.program_guard(program):
            cloned_names = set()
            for op in ops:
                target_block.desc.append_op().copy_from(op.desc)

                for var_name in op.input_arg_names:
                    if var_name in cloned_names:
                        continue
                    if src_vars[var_name].is_parameter:
                        _clone_param(src_vars[var_name])
                    else:
                        _clone_non_param(var_name)
                    _keep_original_id(var_name)
                    cloned_names.add(var_name)

                for var_name in op.output_arg_names:
                    if var_name in cloned_names:
                        continue
                    _clone_non_param(var_name)
                    _keep_original_id(var_name)
                    cloned_names.add(var_name)

        target_block._sync_with_cpp()

        return program

    def _compelte_sub_fwd_program(self, idx, sub_fwd_program, process_mesh):
        """Complete the forward sub program ``idx`` on ``process_mesh``.

        Candidate parallelisms are "dp"/"mp" for a 1-D mesh and
        "dp_mp"/"mp_dp" otherwise. For each candidate:
        - a new DistributedContext is built;
        - all op and tensor dist attrs of the sub program are reset;
        - every tensor with a dims mapping recorded for that parallelism in
          ``self.tensor_dist_attrs`` is annotated.
        If at least one tensor was annotated, forward completion runs and the
        context is stored under
        ``self.sub_programs_dist_context[idx][parallelism][mesh_key]``;
        otherwise an info message is logged.
        ``self.sub_programs_dist_context[idx]`` must already exist.
        """
        if len(process_mesh.shape) == 1:
            candidate_parallelisms = ["dp", "mp"]
        else:
            candidate_parallelisms = ["dp_mp", "mp_dp"]

        for parallelism in candidate_parallelisms:
            annotated_count = 0
            dist_context = DistributedContext(sub_fwd_program)
            dist_context.process_meshes = []
            dist_context.add_process_mesh(process_mesh)
            program_vars = sub_fwd_program.global_block().vars

            # Start from clean op and tensor dist attrs for this parallelism.
            for op in sub_fwd_program.global_block().ops:
                op.dist_attr = OperatorDistAttr(op.desc)
            for var_name in program_vars:
                program_vars[var_name].dist_attr = TensorDistAttr(
                    program_vars[var_name].desc
                )

            # Annotate tensors whose dims mapping is known for this parallelism.
            for var_name in program_vars:
                var = program_vars[var_name]
                var_id = var.desc.original_id()
                if var_id not in self.tensor_dist_attrs:
                    continue
                if parallelism not in self.tensor_dist_attrs[var_id]:
                    continue
                dist_tensor = DistributedTensor(var)
                dist_tensor.dist_attr.process_mesh = process_mesh
                dist_tensor.dist_attr.dims_mapping = self.tensor_dist_attrs[
                    var_id
                ][parallelism]
                dist_tensor.dist_attr.mark_annotated("dims_mapping")
                dist_tensor.dist_attr.mark_annotated("process_mesh")
                dist_context.add_dist_tensor_for_program(dist_tensor)
                annotated_count += 1

            if annotated_count == 0:
                self._logger.info(
                    "No pattern has be matched under {} parallelism whe sub program is {}.".format(
                        parallelism, sub_fwd_program
                    )
                )
                continue

            dist_context.initialize(no_default=True)
            Completer(dist_context).complete_forward_annotation()
            contexts_by_mesh = self.sub_programs_dist_context[idx].setdefault(
                parallelism, {}
            )
            mesh_key = self.convert_process_mesh_to_key(process_mesh)
            contexts_by_mesh[mesh_key] = dist_context

    def complete_sub_fwd_programs(self, process_mesh):
        """Complete all forward sub programs."""
        for idx in self.fwd_sub_programs.keys():
            sub_fwd_program = self.fwd_sub_programs[idx]
            if idx not in self.sub_programs_dist_context:
                self.sub_programs_dist_context[idx] = {}
            self._compelte_sub_fwd_program(idx, sub_fwd_program, process_mesh)

    def _complete_sub_bwd_program(self, sub_program_dist_context):
        """
        Complete the backward OPs of a sub program according to its forward OPs.

        Most of the logic is the same as the backward completion in the
        completer. The difference is the lookup direction: here the backward
        OP is found from the forward OP (through
        ``self.op_original_id_to_grad_op_original_id``), while the completer
        finds the forward OP from the backward OP.

        Dist attrs are written into ``sub_program_dist_context`` for each grad
        op, for its output tensors, and for a ``sum`` op that directly follows
        a grad op in the full program (gradient accumulation), together with
        that ``sum`` op's output tensor.

        Args:
            sub_program_dist_context (DistributedContext): dist context of a
                forward sub program whose forward OPs are already annotated.
        """

        def _is_grad_var_name(name):
            return "@GRAD" in name

        sub_fwd_program = sub_program_dist_context.serial_main_program
        block = sub_fwd_program.global_block()
        vars = self.full_main_program.global_block().vars
        ops = self.full_main_program.global_block().ops
        # NOTE(review): index 1 presumably selects the mapping recorded for
        # the (single) backward pass appended to the full program; confirm.
        grad_var_to_var = (
            self.full_main_program_dist_context.dist_op_context.grad_var_to_var[
                1
            ]
        )
        for forward_op in block.ops:
            fwd_original_id = forward_op.desc.original_id()
            if (
                fwd_original_id
                not in self.op_original_id_to_grad_op_original_id
            ):
                continue
            grad_op_id = self.op_original_id_to_grad_op_original_id[
                fwd_original_id
            ]
            # for unsqueeze2 op in gpt, it has no grad op
            # or for no need to bwd
            if grad_op_id not in self.op_original_id_to_op:
                continue
            grad_op = self.op_original_id_to_op[grad_op_id]
            if grad_op.type == "concat" and forward_op.type == "split":
                # The grad of split is a concat: every input and the output
                # take the dims mapping of split's X input.
                forward_op_dist_attr = (
                    sub_program_dist_context.get_op_dist_attr_for_program(
                        forward_op
                    )
                )
                output_var = vars[grad_op.desc.output('Out')[0]]
                split_input_var_name = forward_op.input("X")[0]
                ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping(
                    split_input_var_name
                )
                ref_mesh = forward_op_dist_attr.process_mesh

                grad_op_dist_attr = OperatorDistAttr()
                for input_name in grad_op.input_arg_names:
                    grad_op_dist_attr.set_input_dims_mapping(
                        input_name, ref_dims_mapping
                    )

                output_var_dist_attr = TensorDistAttr()
                output_var_dist_attr.dims_mapping = ref_dims_mapping
                output_var_dist_attr.process_mesh = ref_mesh
                sub_program_dist_context.set_tensor_dist_attr_for_program(
                    output_var, output_var_dist_attr
                )

                grad_op_dist_attr.set_output_dims_mapping(
                    output_var.name, ref_dims_mapping
                )
                grad_op_dist_attr.process_mesh = ref_mesh
                # Inherit the implementation chosen for this split op. The
                # attr is fully populated before it is registered.
                grad_op_dist_attr.impl_type = forward_op_dist_attr.impl_type
                grad_op_dist_attr.impl_idx = forward_op_dist_attr.impl_idx
                sub_program_dist_context.set_op_dist_attr_for_program(
                    grad_op, grad_op_dist_attr
                )
                continue

            fwd_op_dist_attr = (
                sub_program_dist_context.get_op_dist_attr_for_program(
                    forward_op
                )
            )
            fwd_op_process_mesh = fwd_op_dist_attr.process_mesh
            grad_op_dist_attr = OperatorDistAttr()
            grad_op_dist_attr.process_mesh = fwd_op_process_mesh

            for input_name in grad_op.input_arg_names:
                if (
                    input_name not in forward_op.input_arg_names
                    and input_name not in forward_op.output_arg_names
                ):
                    # Not a tensor of the forward op itself: either the grad
                    # of a forward output, or a tensor annotated elsewhere.
                    if input_name in grad_var_to_var:
                        fwd_name = grad_var_to_var[input_name]
                        ref_dims_mapping = (
                            fwd_op_dist_attr.get_output_dims_mapping(fwd_name)
                        )
                    else:
                        input_var = vars[input_name]
                        ref_dims_mapping = sub_program_dist_context.get_tensor_dist_attr_for_program(
                            input_var
                        ).dims_mapping
                else:
                    if input_name in forward_op.input_arg_names:
                        ref_dims_mapping = (
                            fwd_op_dist_attr.get_input_dims_mapping(input_name)
                        )
                    else:
                        ref_dims_mapping = (
                            fwd_op_dist_attr.get_output_dims_mapping(input_name)
                        )
                assert (
                    ref_dims_mapping is not None
                ), "[{}] 's dims mapping is NONE".format(input_name)
                grad_op_dist_attr.set_input_dims_mapping(
                    input_name, ref_dims_mapping
                )

            # Each grad output mirrors the forward input it is the grad of.
            for output_name in grad_op.output_arg_names:
                assert output_name in grad_var_to_var
                fwd_name = grad_var_to_var[output_name]
                ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping(
                    fwd_name
                )
                # var
                output_var = vars[output_name]
                tensor_dist_attr = TensorDistAttr()
                tensor_dist_attr.dims_mapping = ref_dims_mapping
                tensor_dist_attr.process_mesh = fwd_op_process_mesh
                sub_program_dist_context.set_tensor_dist_attr_for_program(
                    output_var, tensor_dist_attr
                )
                # op
                grad_op_dist_attr.set_output_dims_mapping(
                    output_name, ref_dims_mapping
                )

            grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type
            grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx
            sub_program_dist_context.set_op_dist_attr_for_program(
                grad_op, grad_op_dist_attr
            )

            # A "sum" right after the grad op accumulates gradients; annotate
            # it from the forward var its output is the grad of.
            grad_op_idx = self.op_original_id_to_idx[grad_op_id]
            if grad_op_idx + 1 < len(ops):
                grad_op_next_op = ops[grad_op_idx + 1]
                if grad_op_next_op.type == "sum":
                    assert all(
                        map(_is_grad_var_name, grad_op_next_op.input_arg_names)
                    )
                    output_name = grad_op_next_op.output_arg_names[0]
                    assert (
                        output_name in grad_var_to_var
                    ), "sum op's output '{}' has no corresponding var".format(
                        output_name
                    )
                    ref_fwd_var_name = grad_var_to_var[output_name]
                    ref_fwd_var = vars[ref_fwd_var_name]
                    ref_fwd_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                        ref_fwd_var
                    )
                    ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping
                    ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh

                    # output
                    tensor_dist_attr = TensorDistAttr()
                    tensor_dist_attr.dims_mapping = ref_fwd_dims_mapping
                    tensor_dist_attr.process_mesh = ref_fwd_process_mesh
                    output_var = vars[output_name]
                    sub_program_dist_context.set_tensor_dist_attr_for_program(
                        output_var, tensor_dist_attr
                    )

                    # op
                    grad_op_dist_attr = OperatorDistAttr()
                    grad_op_dist_attr.process_mesh = ref_fwd_process_mesh

                    for var_name in grad_op_next_op.input_arg_names:
                        grad_op_dist_attr.set_input_dims_mapping(
                            var_name, ref_fwd_dims_mapping
                        )
                    grad_op_dist_attr.set_output_dims_mapping(
                        output_name, ref_fwd_dims_mapping
                    )
                    grad_op_dist_attr.impl_type = "default"
                    grad_op_dist_attr.impl_idx = 0

                    sub_program_dist_context.set_op_dist_attr_for_program(
                        grad_op_next_op, grad_op_dist_attr
                    )

    def complete_sub_bwd_programs(self):
        """
        Run backward completion on every stored sub program dist context.

        ``self.sub_programs_dist_context`` is nested as
        ``{layer_idx: {parallelism: {process_mesh_key: dist_context}}}``.
        """
        for contexts_by_parallelism in self.sub_programs_dist_context.values():
            for contexts_by_key in contexts_by_parallelism.values():
                for dist_context in contexts_by_key.values():
                    self._complete_sub_bwd_program(dist_context)
1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767 1768 1769 1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786 1787 1788 1789 1790 1791 1792 1793 1794 1795 1796 1797 1798 1799 1800 1801 1802 1803 1804 1805 1806 1807 1808 1809 1810 1811 1812 1813 1814 1815 1816 1817 1818 1819 1820 1821 1822 1823 1824 1825 1826 1827 1828 1829 1830 1831 1832 1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847 1848 1849 1850 1851 1852 1853 1854 1855 1856 1857 1858 1859 1860 1861 1862 1863 1864 1865 1866 1867 1868 1869 1870 1871 1872 1873 1874 1875 1876 1877 1878 1879 1880 1881 1882 1883 1884 1885 1886 1887 1888 1889 1890 1891 1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934 1935 1936 1937 1938 1939 1940 1941 1942 1943 1944 1945 1946 1947 1948 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 1961 1962 1963 1964 1965 1966 1967 1968 1969 1970 1971 1972 1973 1974 1975 1976 1977 1978 1979 1980 1981 1982 1983 1984 1985 1986 1987 1988 1989 1990 1991 1992 1993 1994 1995 1996 1997 1998 1999 2000 2001 2002 2003 2004 2005 2006 2007 2008 2009 2010 2011 2012 2013 2014 2015 2016 2017 2018 2019 2020 2021 2022 2023 2024 2025 2026 
2027 2028 2029 2030 2031 2032 2033 2034 2035 2036 2037 2038 2039 2040 2041 2042 2043 2044 2045 2046 2047 2048 2049 2050 2051 2052 2053 2054 2055 2056 2057 2058 2059 2060 2061 2062 2063 2064 2065 2066 2067 2068 2069 2070 2071 2072 2073 2074 2075 2076 2077 2078 2079 2080 2081 2082 2083 2084 2085 2086 2087 2088 2089 2090 2091 2092 2093 2094 2095 2096 2097 2098 2099 2100 2101 2102 2103 2104 2105 2106 2107 2108 2109 2110 2111 2112 2113 2114 2115 2116 2117 2118 2119 2120 2121 2122 2123 2124 2125 2126 2127 2128 2129 2130 2131 2132 2133 2134 2135 2136 2137 2138 2139 2140 2141 2142 2143 2144 2145 2146 2147 2148 2149 2150 2151 2152 2153 2154 2155 2156 2157 2158 2159 2160 2161 2162 2163 2164 2165 2166 2167 2168 2169 2170 2171 2172 2173 2174 2175 2176 2177 2178 2179 2180 2181 2182 2183 2184 2185 2186 2187 2188 2189 2190 2191 2192 2193 2194 2195 2196 2197 2198 2199 2200 2201 2202 2203 2204 2205 2206 2207 2208 2209 2210 2211 2212 2213 2214 2215 2216 2217 2218 2219 2220 2221 2222 2223 2224 2225 2226

    def _complete_sub_update_program(self, sub_program_dist_context):
        """
        Complete the optimizer (update) OPs according to the annotated tensors.

        Most of the logic is the same as the update completion in the
        completer. Only OPs whose ``op_role`` is ``OpRole.Optimize`` are
        handled:

        * gradient-clip OPs of type ``sum``/``sqrt``/``fill_constant``/
          ``elementwise_max``/``elementwise_div`` are placed on a mesh of all
          ranks; their not-yet-annotated tensors get an all ``-1`` dims
          mapping on that mesh;
        * other gradient-clip OPs follow the dist attr of their ``X`` input
          (skipped when that input is not annotated in this context);
        * OPs with ``Param`` and ``Grad`` inputs follow the dist attr of the
          parameter; the learning rate and Beta1Pow/Beta2Pow inputs get dims
          mapping ``[-1]``.

        Args:
            sub_program_dist_context (DistributedContext): the dist context
                to fill in.
        """
        num_ranks = (
            self._cluster.get_num_machines()
            * self._cluster._num_devices_per_machine
        )
        world_ranks = ProcessMesh(list(range(num_ranks)))
        dist_tensors = sub_program_dist_context._dist_tensors_for_program

        vars = self.full_main_program.global_block().vars
        ops = self.full_main_program.global_block().ops
        # The learning rate tensor's own dist attr is set only once, by the
        # first update op that is completed.
        learning_rate_completed = False
        for idx, op in enumerate(ops):
            if int(op.attr('op_role')) != int(OpRole.Optimize):
                continue

            if is_gradient_clip_op(op):
                if op.type in [
                    "sum",
                    "sqrt",
                    "fill_constant",
                    "elementwise_max",
                    "elementwise_div",
                ]:
                    op_dist_attr = OperatorDistAttr()
                    op_dist_attr.process_mesh = world_ranks
                    for in_name in op.input_arg_names:
                        in_var = vars[in_name]
                        if in_var.desc.original_id() in dist_tensors:
                            in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                                in_var
                            )
                            op_dist_attr.set_input_dist_attr(
                                in_name, in_dist_attr
                            )
                        else:
                            in_dist_attr = TensorDistAttr()
                            in_dist_attr.process_mesh = world_ranks
                            in_dist_attr.dims_mapping = [-1] * len(
                                in_var.shape
                            )
                            op_dist_attr.set_input_dist_attr(
                                in_name, in_dist_attr
                            )
                            sub_program_dist_context.set_tensor_dist_attr_for_program(
                                in_var, in_dist_attr
                            )
                    for out_name in op.output_arg_names:
                        out_var = vars[out_name]
                        if out_var.desc.original_id() in dist_tensors:
                            out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                                out_var
                            )
                            op_dist_attr.set_output_dist_attr(
                                out_name, out_dist_attr
                            )
                        else:
                            out_dist_attr = TensorDistAttr()
                            out_dist_attr.process_mesh = world_ranks
                            out_dist_attr.dims_mapping = [-1] * len(
                                out_var.shape
                            )
                            sub_program_dist_context.set_tensor_dist_attr_for_program(
                                out_var, out_dist_attr
                            )
                            op_dist_attr.set_output_dist_attr(
                                out_name, out_dist_attr
                            )
                    sub_program_dist_context.set_op_dist_attr_for_program(
                        op, op_dist_attr
                    )
                else:
                    in_var = vars[op.input("X")[0]]
                    if in_var.desc.original_id() not in dist_tensors:
                        continue
                    in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                        in_var
                    )
                    assert in_dist_attr is not None
                    ref_process_mesh = in_dist_attr.process_mesh
                    ref_dims_mapping = in_dist_attr.dims_mapping

                    # A cast feeding an elementwise_mul takes the process mesh
                    # of the mul's X input. Guard the lookahead so a trailing
                    # cast does not raise IndexError.
                    if (
                        op.type == "cast"
                        and idx + 1 < len(ops)
                        and ops[idx + 1].type == "elementwise_mul"
                    ):
                        ref_var = vars[ops[idx + 1].input("X")[0]]
                        ref_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                            ref_var
                        )
                        assert ref_dist_attr is not None
                        ref_process_mesh = ref_dist_attr.process_mesh

                    out_var = vars[op.output("Out")[0]]
                    out_dist_attr = TensorDistAttr()
                    out_dist_attr.process_mesh = ref_process_mesh
                    if out_var.shape == in_var.shape:
                        out_dist_attr.dims_mapping = ref_dims_mapping
                    else:
                        assert (
                            len(out_var.shape) == 1 and out_var.shape[0] == 1
                        )
                        out_dist_attr.dims_mapping = [-1]
                    sub_program_dist_context.set_tensor_dist_attr_for_program(
                        out_var, out_dist_attr
                    )

                    op_dist_attr = OperatorDistAttr()
                    op_dist_attr.process_mesh = ref_process_mesh
                    for in_name in op.input_arg_names:
                        in_var = vars[in_name]
                        in_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                            in_var
                        )
                        op_dist_attr.set_input_dims_mapping(
                            in_name, in_dist_attr.dims_mapping
                        )
                    for out_name in op.output_arg_names:
                        out_var = vars[out_name]
                        out_dist_attr = sub_program_dist_context.get_tensor_dist_attr_for_program(
                            out_var
                        )
                        op_dist_attr.set_output_dims_mapping(
                            out_name, out_dist_attr.dims_mapping
                        )
                    # NOTE: in_var/out_var here are the last ones bound by
                    # the two loops above; their full dist attrs are attached.
                    op_dist_attr.set_input_dist_attr(in_var.name, in_dist_attr)
                    op_dist_attr.set_output_dist_attr(
                        out_var.name, out_dist_attr
                    )

                    sub_program_dist_context.set_op_dist_attr_for_program(
                        op, op_dist_attr
                    )

            if "Grad" in op.input_names and "Param" in op.input_names:
                assert (
                    len(op.input("Param")) == 1
                ), "Only support one-to-one now."
                assert (
                    len(op.input("Grad")) == 1
                ), "Only support one-to-one now."
                param = vars[op.input("Param")[0]]
                grad_var = vars[op.input("Grad")[0]]
                if param.desc.original_id() not in dist_tensors:
                    continue

                param_dist_attr = (
                    sub_program_dist_context.get_tensor_dist_attr_for_program(
                        param
                    )
                )
                assert param_dist_attr is not None
                ref_process_mesh = param_dist_attr.process_mesh
                assert ref_process_mesh is not None
                ref_dims_mapping = param_dist_attr.dims_mapping
                assert ref_dims_mapping is not None
                op_dist_attr = OperatorDistAttr()
                op_dist_attr.process_mesh = ref_process_mesh
                op_dist_attr.set_input_dims_mapping(
                    grad_var.name, ref_dims_mapping
                )
                op_dist_attr.set_input_dims_mapping(
                    param.name, ref_dims_mapping
                )
                op_dist_attr.set_output_dims_mapping(
                    param.name, ref_dims_mapping
                )
                learning_var = vars[op.input("LearningRate")[0]]
                op_dist_attr.set_input_dims_mapping(learning_var.name, [-1])
                op_dist_attr.set_output_dims_mapping(learning_var.name, [-1])

                if not learning_rate_completed:
                    learning_rate_completed = True
                    var_dist_attr = TensorDistAttr()
                    var_dist_attr.process_mesh = world_ranks
                    var_dist_attr.dims_mapping = [-1]
                    sub_program_dist_context.set_tensor_dist_attr_for_program(
                        learning_var, var_dist_attr
                    )

                # Remaining optimizer-state inputs follow the parameter,
                # except Beta1Pow/Beta2Pow which get dims mapping [-1].
                for input_name in op.desc.input_names():
                    if input_name in [
                        'Param',
                        'Grad',
                        'LearningRate',
                        "SkipUpdate",
                        "Beta1Tensor",
                        "Beta2Tensor",
                        "EpsilonTensor",
                    ]:
                        continue
                    if len(op.desc.input(input_name)) == 0:
                        continue

                    assert len(op.desc.input(input_name)) == 1
                    input_var = vars[op.desc.input(input_name)[0]]
                    input_var_attr = TensorDistAttr()

                    if "Beta1Pow" in input_name or "Beta2Pow" in input_name:
                        input_var_attr.dims_mapping = [-1]
                        op_dist_attr.set_input_dims_mapping(
                            input_var.name, [-1]
                        )
                        op_dist_attr.set_output_dims_mapping(
                            input_var.name, [-1]
                        )
                    else:
                        input_var_attr.dims_mapping = ref_dims_mapping
                        op_dist_attr.set_input_dims_mapping(
                            input_var.name, ref_dims_mapping
                        )
                        op_dist_attr.set_output_dims_mapping(
                            input_var.name, ref_dims_mapping
                        )

                    input_var_attr.process_mesh = ref_process_mesh
                    sub_program_dist_context.set_tensor_dist_attr_for_program(
                        input_var, input_var_attr
                    )

                sub_program_dist_context.set_op_dist_attr_for_program(
                    op, op_dist_attr
                )

    def complete_sub_update_programs(self):
        """
        Run update (optimizer) completion on every stored sub program dist
        context, nested as
        ``{layer_idx: {parallelism: {process_mesh_key: dist_context}}}``.
        """
        for contexts_by_parallelism in self.sub_programs_dist_context.values():
            for contexts_by_key in contexts_by_parallelism.values():
                for dist_context in contexts_by_key.values():
                    self._complete_sub_update_program(dist_context)

    def convert_device_mesh_to_key(self, device_mesh):
        """
        Serialize a device mesh into a string key "<device ids>;<shape>",
        with both parts comma-separated, e.g. "0,1,2,3;2,2".
        """
        ids_part = ",".join(map(str, device_mesh.device_ids))
        shape_part = ",".join(map(str, device_mesh.shape))
        return "{};{}".format(ids_part, shape_part)

    def _get_sub_program_cost(self, dist_context):
        """
        Estimate the cost of a dist context with the cost model.

        Returns:
            tuple: ``(time, max_memory)``, where ``time`` is the ``time``
            field of the estimated global cost and ``max_memory`` is the
            estimated peak memory (presumably in bytes, since ``tune_o1``
            divides it by 1024**3 to report GB; confirm).
        """
        estimator = CostEstimator(self.full_main_program, self._cluster)
        total_cost = estimator.estimate(dist_context)
        peak_memory = estimator._estimate_max_memory_by_dist_op(dist_context)
        return total_cost.time, peak_memory

    def combine_dist_contexts(self, dist_contexts):
        """
        Merge the dist attrs of several dist contexts into a new one.

        Dist tensors are taken from the first context that contains them, so
        a tensor shared by several contexts (e.g. a shared param or a var
        used as input by ops of several layers) keeps its first annotation.
        Every dist op and process mesh of every context is added.
        """
        combined = DistributedContext()
        for context in dist_contexts:
            for tensor_id, dist_tensor in context._dist_tensors_for_program.items():
                if tensor_id in combined._dist_tensors_for_program:
                    continue
                combined.add_dist_tensor_for_program(dist_tensor)

            for dist_op in context._dist_ops_for_program.values():
                combined.add_dist_op_for_program(dist_op)

            for mesh in context.process_meshes:
                combined.add_process_mesh(mesh)

        return combined

    def prepare(self):
        """
        Prepare everything the tuner searches over.

        In order: cluster operators into layers, clone a forward sub program
        per layer, partition the cluster into device meshes and derive the
        candidate process meshes, generate the full program, then complete
        the forward sub programs for every process mesh, plus the backward
        and update sub programs when ``self.mode == "train"``. The duration
        of each step is logged.
        """

        # step1: cluster operators to layers
        start_time = time.time()
        self.layers = self.cluster_operators()
        self._logger.info(
            "Cluster operators to {} layers in {}s.".format(
                len(self.layers), time.time() - start_time
            )
        )

        # step2: generate sub program of each layer
        start_time = time.time()
        self.gen_fwd_sub_programs_by_clone()
        self._logger.info(
            "Generate programs of every layer in {}s.".format(
                time.time() - start_time
            )
        )

        # step3: partition devices to device meshes
        start_time = time.time()
        num_machines = self._cluster.get_num_machines()
        devices_per_machine = self._cluster._num_devices_per_machine
        device_meshes_list = ClusterPartitionUtil.partition_cluster(
            num_machines, devices_per_machine
        )
        self._logger.info(
            "Partition cluster in {}s.".format(time.time() - start_time)
        )

        # step4: transform device mesh to process meshes
        mesh_counter = 0
        for device_meshes in device_meshes_list:
            stage_meshes = []
            self.device_meshes_list.append(stage_meshes)
            offset = 0
            for device_mesh in device_meshes:
                num_devices = reduce(lambda a, b: a * b, device_mesh)
                processes = list(range(offset, offset + num_devices))
                # Drop a leading 1 from the shape, e.g. [1, 8] -> [8].
                if device_mesh[0] == 1:
                    mesh_shape = list(device_mesh[1:])
                else:
                    mesh_shape = device_mesh
                stage_meshes.append(
                    DeviceMesh(
                        mesh=np.array(processes).reshape(mesh_shape).tolist(),
                        name="device_mesh_" + str(mesh_counter),
                    )
                )
                mesh_counter += 1
                offset += num_devices
                for process_mesh_shape in convert_to_process_meshes(
                    device_mesh
                ):
                    candidate = ProcessMesh(
                        np.array(processes).reshape(process_mesh_shape).tolist()
                    )
                    if candidate not in self.process_meshes:
                        self.process_meshes.append(candidate)

        # step5: generate full program
        start_time = time.time()
        self.gen_full_program()
        self._logger.info(
            "Generate full program in {}s.".format(time.time() - start_time)
        )

        # step6: complete forward sub programs
        start_time = time.time()
        for process_mesh in self.process_meshes:
            self.complete_sub_fwd_programs(process_mesh)
        self._logger.info(
            "Complete all sub forward programs in {}s.".format(
                time.time() - start_time
            )
        )

        if self.mode != "train":
            return

        # step7: complete backward sub programs
        start_time = time.time()
        self.complete_sub_bwd_programs()
        self._logger.info(
            "Complete all sub backward programs in {}s.".format(
                time.time() - start_time
            )
        )

        # step8: complete update sub programs
        start_time = time.time()
        self.complete_sub_update_programs()
        self._logger.info(
            "Complete all sub update programs in {}s.".format(
                time.time() - start_time
            )
        )

    def tune_o1(self):
        """
        The o1 level tuning.

        Each entry of ``self.device_meshes_list`` is one candidate pipeline
        partition (one device mesh per stage). Layers are split evenly across
        the stages, with the last stage taking the remainder. For every
        parallelism ("dp", "mp", "dp_mp", "mp_dp") and every process mesh
        shape derived from the first stage's device mesh shape, the per-layer
        sub program dist contexts are combined and their cost is estimated.
        A candidate whose estimated max memory exceeds 85% of the first
        device's memory is treated as infinitely expensive.

        Returns:
            DistributedContext or None: the combined dist context with the
            lowest estimated cost, or None if no candidate beat
            ``sys.maxsize``.
        """
        best_cost = sys.maxsize
        best_dist_context = None

        for device_meshes in self.device_meshes_list:
            pp_stages = len(device_meshes)
            average_layers = len(self.layers) // pp_stages
            # Work on a copy so the device mesh's own shape is never mutated.
            device_mesh_shape = list(device_meshes[0].shape)
            if len(device_mesh_shape) == 1:
                device_mesh_shape.insert(0, 1)
            process_mesh_shapes = convert_to_process_meshes(device_mesh_shape)

            # For example, device_mesh is [1, 8] and process_mesh is [8].
            # The selective parallelism is dp or mp.
            # Get dp8 or mp8 cost and compare them to get the best strategy.
            for parallelism in ["dp", "mp", "dp_mp", "mp_dp"]:
                for process_mesh_shape in process_mesh_shapes:
                    dist_context_of_device_meshes = None
                    for idx, device_mesh in enumerate(device_meshes):
                        process_mesh = ProcessMesh(
                            np.array(device_mesh.device_ids)
                            .reshape(process_mesh_shape)
                            .tolist()
                        )

                        # 1D process meshes take dp/mp; higher-dimensional
                        # ones take the hybrid dp_mp/mp_dp.
                        selective_parallelisms = (
                            ["dp", "mp"]
                            if len(process_mesh.shape) == 1
                            else ["dp_mp", "mp_dp"]
                        )
                        if parallelism not in selective_parallelisms:
                            continue

                        key = self.convert_process_mesh_to_key(process_mesh)

                        # Even split of layers; the last stage takes the rest.
                        start = idx * average_layers
                        end = (
                            len(self.layers)
                            if idx == len(device_meshes) - 1
                            else (idx + 1) * average_layers
                        )

                        dist_context = self.combine_dist_contexts(
                            [
                                self.sub_programs_dist_context[j][parallelism][
                                    key
                                ]
                                for j in range(start, end)
                            ]
                        )

                        dist_context_of_device_meshes = (
                            dist_context
                            if dist_context_of_device_meshes is None
                            else self.combine_dist_contexts(
                                [dist_context_of_device_meshes, dist_context]
                            )
                        )
                    if dist_context_of_device_meshes is None:
                        continue

                    cost, memory = self._get_sub_program_cost(
                        dist_context_of_device_meshes
                    )

                    self._logger.info(
                        "Cost Model: The max memory is {}GB and cost is {} when {} parallelism under process mesh shape {} on {} stages.".format(
                            memory / (1024**3),
                            cost,
                            parallelism,
                            process_mesh_shape,
                            len(device_meshes),
                        )
                    )
                    # 15% buffer is reserved for memory cost
                    if memory > 0.85 * self._cluster.machines[0].devices[
                        0
                    ].memory * (1024**3):
                        cost = sys.maxsize

                    if cost < best_cost:
                        best_cost = cost
                        best_dist_context = dist_context_of_device_meshes
                        self._logger.info(
                            "O1 level: a better strategy has be found that parallelism is {} under process mesh shape {} on {} stages with max memory {}GB.".format(
                                parallelism,
                                process_mesh_shape,
                                len(device_meshes),
                                memory / (1024**3),
                            )
                        )

        return best_dist_context

    def tune_o2(self):
        """Placeholder for the o2-level strategy search.

        The o2 search is not implemented yet: no strategy is produced and
        ``None`` is always returned. ``tune`` asserts that a strategy was
        found, so selecting o2 currently ends in that assertion failing.

        Returns:
            None
        """
        no_strategy = None
        return no_strategy

    def save_strategy(self, best_dist_context, path):
        """Pickle the dist attributes of ``best_dist_context`` into ``path``.

        Only tensors and ops whose keys also exist in ``self._dist_context``
        are kept. The pickled payload is a dict with these keys:

        * ``"tensor"``: key -> ``dist_attr.serialize_to_string()`` of the tensor
        * ``"op"``: key -> ``dist_attr.serialize_to_string()`` of the op
        * ``"process_meshes"``: list of ``[process_ids, shape]`` pairs
        * ``"cluster"``: ``self._cluster`` as-is

        Args:
            best_dist_context: context whose tensor/op dist attributes and
                process meshes are saved.
            path: destination file; it is opened in ``'wb'`` mode and
                therefore overwritten.

        Raises:
            AssertionError: if no tensor or no op dist attribute was
                collected.
        """
        known_tensors = self._dist_context._dist_tensors_for_program
        tensor_attrs = {
            key: dist_tensor.dist_attr.serialize_to_string()
            for key, dist_tensor in best_dist_context._dist_tensors_for_program.items()
            if key in known_tensors
        }
        assert tensor_attrs, "Tensor dist attrs must not be None."

        known_ops = self._dist_context._dist_ops_for_program
        op_attrs = {
            key: dist_op.dist_attr.serialize_to_string()
            for key, dist_op in best_dist_context._dist_ops_for_program.items()
            if key in known_ops
        }
        assert op_attrs, "Op dist attrs must not be None."

        mesh_specs = [
            [mesh.process_ids, mesh.shape]
            for mesh in best_dist_context._process_meshes
        ]

        # Key insertion order matches the historical on-disk layout.
        payload = {
            "tensor": tensor_attrs,
            "op": op_attrs,
            "process_meshes": mesh_specs,
            "cluster": self._cluster,
        }
        with open(path, 'wb') as f:
            pickle.dump(payload, f)
        self._logger.info("The strategy has been saved at {}".format(path))

    def run_or_quit(self):
        """Terminate the process when the tuner was asked only to tune.

        If ``self._is_run`` is falsy, log a notice and raise ``SystemExit``
        (exit code ``None``) via ``sys.exit()``. Otherwise return ``None``
        so the caller can continue.

        ``sys.exit`` is used instead of the ``quit`` builtin. ``quit`` is
        added by the ``site`` module for the interactive shell only. It does
        not exist when Python runs with ``-S`` or in some embedded or frozen
        interpreters, where calling it raises ``NameError`` instead of
        exiting.

        Raises:
            SystemExit: when ``self._is_run`` is falsy.
        """
        if self._is_run:
            return
        self._logger.info(
            "The process will be quitted when just tune not run."
        )
        sys.exit()

    def tune(self):
        """Search for a parallel strategy and apply it to ``self._dist_context``.

        Steps:
          1. Pattern-match the serial main program of ``self._dist_context``.
          2. If ``self._use_dp`` is set:
             a. complete the forward annotation with a ``Completer``;
             b. print the annotated program;
             c. if ``self._strategy_path`` is set, save the strategy and call
                ``run_or_quit``;
             d. return early.
          3. Otherwise call ``prepare()`` and search by ``self.level``:
             ``"o2"`` uses ``tune_o2``. ``"o1"`` uses ``tune_o1``, unless some
             multi-stage entry of ``self.device_meshes_list`` contains device
             meshes of differing shapes; in that case it falls back to
             ``tune_o2``.
          4. Copy the best context's tensor/op dist attributes into
             ``self._dist_context``, only for keys already present there, and
             adopt its process meshes.
          5. Log elapsed time and print the resulting program. If
             ``self._strategy_path`` is set, save the strategy and call
             ``run_or_quit``.

        Raises:
            AssertionError: if no strategy was found. This includes the case
                where ``self.level`` is neither ``"o1"`` nor ``"o2"``.
        """
        start_time = time.time()
        self.match_program(self._dist_context.serial_main_program)
        self._logger.info(
            "Pattern match in {}s.".format(time.time() - start_time)
        )

        if self._use_dp:
            # Pure data parallelism: annotations come from completion alone.
            Completer(self._dist_context).complete_forward_annotation()
            print_program_with_dist_attr(
                self._dist_context.serial_main_program, self._dist_context
            )
            strategy_path = self._strategy_path
            if strategy_path:
                self.save_strategy(self._dist_context, strategy_path)
                self.run_or_quit()
            return

        self.prepare()

        best_dist_context = None
        if self.level == "o2":
            best_dist_context = self.tune_o2()
        elif self.level == "o1":
            # o1 keeps all layers under one parallelism and, with pipelining,
            # splits layers evenly across stages. It is only used when every
            # multi-stage candidate has device meshes of a single shape.
            fall_back_to_o2 = False
            for device_meshes in self.device_meshes_list:
                if len(device_meshes) <= 1:
                    continue
                meshes = iter(device_meshes)
                reference_shape = next(meshes).shape
                # One warning per offending candidate; any() stops at the
                # first mismatch.
                if any(reference_shape != mesh.shape for mesh in meshes):
                    self._logger.info(
                        "Warning: The o1 level is not be supported when the number of machines is prime numer which greaters than 1. We will use o2 level to tune."
                    )
                    fall_back_to_o2 = True
            best_dist_context = (
                self.tune_o2() if fall_back_to_o2 else self.tune_o1()
            )

        assert (
            best_dist_context is not None
        ), "can not find a parallel strategy to run, please use passes such as recompute, amp or sharding."

        # Only overwrite entries our own context already knows about.
        own_tensors = self._dist_context._dist_tensors_for_program
        for key, dist_tensor in best_dist_context._dist_tensors_for_program.items():
            if key in own_tensors:
                own_tensors[key] = dist_tensor
        own_ops = self._dist_context._dist_ops_for_program
        for key, dist_op in best_dist_context._dist_ops_for_program.items():
            if key in own_ops:
                own_ops[key] = dist_op
        self._dist_context._process_meshes = best_dist_context._process_meshes

        self._logger.info(
            "Rule-based tuner end in {}s.".format(time.time() - start_time)
        )
        self._logger.info("The best strategy found is as follows: ")
        print_program_with_dist_attr(self.full_main_program, best_dist_context)

        strategy_path = self._strategy_path
        if strategy_path:
            self.save_strategy(best_dist_context, strategy_path)
            self.run_or_quit()