# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import math
import copy
import hashlib
import itertools
from collections import defaultdict
import numpy as np
from ..process_mesh import ProcessMesh
from ..completion import Completer
from ..parallelizer_v2 import Parallelizer
from ..dist_context import _node_id
from ..dist_op import DistributedOperator
from ..operators.common import find_compatible_distributed_operator_impls
from .trial import Trial, TrialStatus
from .tunable_space import TunableSpace
from .tunable_variable import Boolean, IntRange
from ..cost import CostEstimator
from .tunable_variable import Boolean, IntRange


class ParallelTuner:
    def __init__(
        self,
        dist_context,
        mode="train",
        max_trials=25,
        tuner_id=None,
        seed=None,
        logger=None,
        loop_count=10,
    ):
        """Set up the tuner state for an already initialized dist context.

        Args:
            dist_context: Distributed context to tune. Its
                ``_is_initialized`` flag must be truthy; its ``cluster``
                provides the machine and per-machine device counts.
            mode: Mode string forwarded to the ``Parallelizer``.
            max_trials: Upper bound on the number of trials to create.
            tuner_id: Optional identifier stored on the tuner.
            seed: Seed for the random generators; 9999 is used when ``None``.
            logger: Optional logger stored on the tuner.
            loop_count: Stored as ``_loop_count``.
        """
        self._loop_count = loop_count
        self._estimator = None
        self._dist_context = dist_context
        assert self._dist_context._is_initialized
        self._mode = mode
        self._cluster = self._dist_context.cluster
        self._num_machines = self._cluster.get_num_machines()
        self._num_devices_per_machine = (
            self._cluster.get_num_devices_per_machine()
        )
        self._space = TunableSpace()
        self._objective = "time"
        self._direction = "min"
        self._max_trials = max_trials
        self._tuner_id = tuner_id
        self._seed = seed if seed is not None else 9999

        print(
            "seed",
            self._seed,
            "mode",
            self._mode,
            "num_machies",
            self._num_machines,
            "num_devices_per_machine",
            self._num_devices_per_machine,
            flush=True,
        )
        # _seed_state advances by one for every sampled variable, while
        # _seed itself stays fixed.
        self._seed_state = self._seed
        self._logger = logger
        # Number of already-tried samples tolerated before sampling gives up.
        self._max_collisions = 3
        self._tried_values = set()
        self._num_trials = 0
        self._rng = np.random.default_rng(self._seed)

        # Search the op types in the include_op_types,
        # and will search all op types if it is empty.
        # Exclude the op types in the exclude_op_types
        # from the search list.
        self._exclude_op_types = []
        self._include_op_types = []
        # The final dist ops will be searched after considering
        # the include_op_types and exclude_op_types.
        self._concerned_dist_ops = {}

        self._op_id_to_dist_attr_candidates = defaultdict(list)
        self._cached_dims_mapping_candidates = {}
        self._cached_candidates_info = defaultdict(list)

        # Op types whose current dist attr is kept as the only candidate.
        self._special_ops = [
            "create_py_reader",
            "create_double_buffer_reader",
            "read",
            "while",
            "read_from_array",
            "write_to_array",
        ]

        # Each parallel strategy has three elements. The first one is for
        # distributed tensors and the third one is for process meshes.
        # NOTE(review): the previous comment also described the second
        # element as "distributed tensors"; presumably it is for distributed
        # operators - confirm.
        self._init_parallel_strategy = [None, None, None]
        self._best_parallel_strategy = [None, None, None]

        self._completer = Completer(self._dist_context)

        self._parallelizer = Parallelizer(
            self._mode, self._completer, self._dist_context
        )

    def _generate_combination(
        self,
        elements,
        target,
        idx,
        partial_candidate,
        candidates,
        num_candidates=None,
    ):
124 125 126 127
        if target == 0:
            candidates.append(copy.deepcopy(partial_candidate))
            return

128 129 130 131 132
        if (
            target < 0
            or idx == len(elements)
            or len(candidates) > num_candidates
        ):
133 134 135 136
            return

        # Use
        partial_candidate.append(elements[idx])
137 138 139 140 141 142 143 144
        self._generate_combination(
            elements,
            target - elements[idx],
            idx,
            partial_candidate,
            candidates,
            num_candidates,
        )
145 146
        # Not use
        partial_candidate.pop()
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
        self._generate_combination(
            elements,
            target,
            idx + 1,
            partial_candidate,
            candidates,
            num_candidates,
        )

    def _permute_combination(
        self,
        combination,
        target,
        check,
        partial_candidate,
        candidates,
        num_candidates=None,
        skip_prob=None,
    ):
        if num_candidates is not None and len(candidates) == num_candidates:
167 168 169 170 171 172 173 174 175 176 177
            return

        if len(partial_candidate) == len(combination):
            candidates.append(partial_candidate)
            return

        for i in range(len(combination)):
            if check[i] == 1:
                continue
            if self._rng.choice([True, False], p=[skip_prob, 1 - skip_prob]):
                continue
178 179 180 181 182
            if (
                i > 0
                and combination[i] == combination[i - 1]
                and check[i - 1] == 0
            ):
183 184
                continue
            check[i] = 1
185 186 187 188 189 190 191 192 193
            self._permute_combination(
                combination,
                target,
                check,
                partial_candidate + [combination[i]],
                candidates,
                num_candidates,
                skip_prob,
            )
194 195 196 197 198 199 200 201 202 203
            check[i] = 0

    def _partition_number(self, target):
        log2_target = int(math.log2(target))
        elements = [pow(2, i) for i in range(log2_target)]
        if pow(2, log2_target) == target:
            elements.append(target)
        seed_candidates = []
        num_seed_candidates = 1000
        partial_results = []
204 205 206 207 208 209 210 211
        self._generate_combination(
            elements,
            target,
            0,
            partial_results,
            seed_candidates,
            num_seed_candidates,
        )
212 213 214 215 216 217 218 219 220 221

        candidates = []
        for seed_candidate in seed_candidates:
            cur_candidates = []
            num_cur_candidates = 16
            seed_candidate.sort()
            check = [0 for i in range(len(seed_candidate))]
            if target <= 8:
                skip_prob = 0.0
            else:
222 223 224 225 226 227 228 229 230 231
                skip_prob = len(seed_candidate) / target
            self._permute_combination(
                seed_candidate,
                target,
                check,
                [],
                cur_candidates,
                num_cur_candidates,
                skip_prob,
            )
232 233 234 235 236 237 238 239
            candidates.extend(cur_candidates)
        return candidates

    def _partition_devices(self, num_machines, num_devices_per_machine):
        inter_node_partitions = self._partition_number(num_machines)
        intra_node_partitions = self._partition_number(num_devices_per_machine)
        return inter_node_partitions, intra_node_partitions

    def _generate_process_mesh_list(
        self, inter_node_partition, intra_node_partition
    ):
243 244 245 246 247 248 249 250
        process_mesh_list = []
        start_row = 0
        start_col = 0
        for m in inter_node_partition:
            start_col = 0
            for n in intra_node_partition:
                process_mesh = []
                for p in range(m):
251 252 253
                    start = (
                        start_row + p
                    ) * self._num_devices_per_machine + start_col
254 255 256 257 258 259 260 261 262
                    tmp = []
                    for q in range(n):
                        tmp.append(start + q)
                    process_mesh.append(tmp)
                process_mesh_list.append(copy.deepcopy(process_mesh))
                start_col += n
            start_row += m
        return process_mesh_list

    def _generate_dims_mapping_candidates_helper(
        self, dims_mapping, dims_list, start, visited, candidates
    ):
266 267 268 269 270 271 272 273 274
        if start == len(dims_mapping) or all(visited):
            candidates.append(copy.deepcopy(dims_mapping))
            return

        for idx, dim in enumerate(dims_list):
            if visited[idx] == False:
                dims_mapping[start] = dim
                visited[idx] = True
                self._generate_dims_mapping_candidates_helper(
275 276
                    dims_mapping, dims_list, start + 1, visited, candidates
                )
277 278
                visited[idx] = False
        dims_mapping[start] = -1
279 280 281
        self._generate_dims_mapping_candidates_helper(
            dims_mapping, dims_list, start + 1, visited, candidates
        )

    def _generate_dims_mapping_candidates(
        self, dims_mapping_len, process_mesh_len
    ):
286 287 288 289 290 291 292 293
        assert dims_mapping_len >= 1 and process_mesh_len >= 1
        key = (dims_mapping_len, process_mesh_len)
        if key in self._cached_dims_mapping_candidates:
            return self._cached_dims_mapping_candidates[key]
        candidates = []
        dims_mapping = [-1 for i in range(dims_mapping_len)]
        dims_list = [i for i in range(process_mesh_len)]
        visited = [False for i in range(process_mesh_len)]
294 295 296
        self._generate_dims_mapping_candidates_helper(
            dims_mapping, dims_list, 0, visited, candidates
        )
297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312
        self._cached_dims_mapping_candidates[key] = candidates
        return candidates

    def _generate_dist_attr_candidates(self, op_id, dist_op):
        # For now, only allow the process meshes have two dimensions
        process_mesh_len = 2
        serial_op = dist_op.serial_op
        op_dist_attr = dist_op.dist_attr
        if serial_op.type in self._special_ops:
            return [copy.deepcopy(op_dist_attr)]
        key = []
        key.append(serial_op.type)
        for input_name in serial_op.input_names:
            key.append(input_name)
            for input_arg_name in serial_op.input(input_name):
                key.append(
313 314
                    len(op_dist_attr.get_input_dims_mapping(input_arg_name))
                )
315 316 317 318
        for output_name in serial_op.output_names:
            key.append(output_name)
            for output_arg_name in serial_op.output(output_name):
                key.append(
319 320
                    len(op_dist_attr.get_output_dims_mapping(output_arg_name))
                )
321 322 323 324 325 326 327 328 329 330 331
        key = tuple(key)

        if key in self._cached_candidates_info:
            cached_dist_attr_candidates = []
            cached_input_arg_names = self._cached_candidates_info[key][0]
            cached_output_arg_names = self._cached_candidates_info[key][1]
            for cached_dist_attr in self._cached_candidates_info[key][2]:
                new_op_dist_attr = copy.deepcopy(dist_op.dist_attr)
                i = 0
                for input_name in serial_op.input_names:
                    for input_arg_name in serial_op.input(input_name):
332 333 334 335 336
                        cached_dims_mapping = (
                            cached_dist_attr.get_input_dims_mapping(
                                cached_input_arg_names[i]
                            )
                        )
337
                        new_op_dist_attr.set_input_dims_mapping(
338 339
                            input_arg_name, cached_dims_mapping
                        )
340 341 342 343
                        i += 1
                i = 0
                for output_name in serial_op.output_names:
                    for output_arg_name in serial_op.output(output_name):
344 345 346 347 348
                        cached_dims_mapping = (
                            cached_dist_attr.get_output_dims_mapping(
                                cached_output_arg_names[i]
                            )
                        )
349
                        new_op_dist_attr.set_output_dims_mapping(
350 351
                            output_arg_name, cached_dims_mapping
                        )
352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380
                        i += 1
                cached_dist_attr_candidates.append(new_op_dist_attr)
            return cached_dist_attr_candidates

        # cached_candidates_info = []
        input_arg_names = []
        for input_name in serial_op.input_names:
            for input_arg_name in serial_op.input(input_name):
                input_arg_names.append(input_arg_name)
        self._cached_candidates_info[key].append(input_arg_names)
        # cached_candidates_info.append(input_arg_names)
        output_arg_names = []
        for output_name in serial_op.output_names:
            for output_arg_name in serial_op.output(output_name):
                output_arg_names.append(output_arg_name)
        self._cached_candidates_info[key].append(output_arg_names)
        # cached_candidates_info.append(output_arg_names)

        new_op_dist_attr = copy.deepcopy(dist_op.dist_attr)
        # Find valid dims_mapping candidates for inputs
        input_names = []
        dims_mapping_generated = []
        inputs_dist_attrs = op_dist_attr.inputs_dist_attrs
        for tensor_name, tensor_dist_attr in inputs_dist_attrs.items():
            original_dims_mapping = tensor_dist_attr.dims_mapping
            dims_mapping_len = len(original_dims_mapping)
            input_names.append(tensor_name)
            if dims_mapping_len < 1:
                dims_mapping_generated.append(
381 382
                    [copy.deepcopy(original_dims_mapping)]
                )
383 384 385
            else:
                dims_mapping_generated.append(
                    self._generate_dims_mapping_candidates(
386 387 388
                        dims_mapping_len, process_mesh_len
                    )
                )
389 390 391 392 393
        input_dims_mapping_candidates = []
        for dims_mapping_list in itertools.product(*dims_mapping_generated):
            dims_mapping_list = list(dims_mapping_list)
            assert len(dims_mapping_list) == len(input_names)
            for i, dims_mapping in enumerate(dims_mapping_list):
394 395 396 397 398 399
                new_op_dist_attr.set_input_dims_mapping(
                    input_names[i], dims_mapping
                )
            new_dist_op = DistributedOperator(
                dist_op.serial_op, new_op_dist_attr
            )
400
            dist_op_impls = find_compatible_distributed_operator_impls(
401 402
                new_dist_op, fwd=True
            )
403 404 405 406 407 408 409 410 411 412 413 414 415
            if dist_op_impls is not None:
                input_dims_mapping_candidates.append(dims_mapping_list)

        # Find valid dims_mapping candidates for outputs
        output_names = []
        dims_mapping_generated = []
        outputs_dist_attrs = op_dist_attr.outputs_dist_attrs
        for tensor_name, tensor_dist_attr in outputs_dist_attrs.items():
            original_dims_mapping = tensor_dist_attr.dims_mapping
            dims_mapping_len = len(original_dims_mapping)
            output_names.append(tensor_name)
            if dims_mapping_len < 1:
                dims_mapping_generated.append(
416 417
                    [copy.deepcopy(original_dims_mapping)]
                )
418 419 420
            else:
                dims_mapping_generated.append(
                    self._generate_dims_mapping_candidates(
421 422 423
                        dims_mapping_len, process_mesh_len
                    )
                )
424 425 426 427 428 429
        output_dims_mapping_candidates = []
        for dims_mapping_list in itertools.product(*dims_mapping_generated):
            dims_mapping_list = list(dims_mapping_list)
            assert len(dims_mapping_list) == len(output_names)
            for i, dims_mapping in enumerate(dims_mapping_list):
                new_op_dist_attr.set_output_dims_mapping(
430 431 432 433 434
                    output_names[i], dims_mapping
                )
            new_dist_op = DistributedOperator(
                dist_op.serial_op, new_op_dist_attr
            )
435
            dist_op_impls = find_compatible_distributed_operator_impls(
436 437
                new_dist_op, fwd=False
            )
438 439 440 441 442
            if dist_op_impls is not None:
                output_dims_mapping_candidates.append(dims_mapping_list)

        if not input_dims_mapping_candidates and output_dims_mapping_candidates:
            inout_dims_mapping_generated = [
443 444
                [[[-2]]],
                output_dims_mapping_candidates,
445
            ]
446 447 448 449 450 451 452 453 454 455 456
        elif (
            input_dims_mapping_candidates and not output_dims_mapping_candidates
        ):
            inout_dims_mapping_generated = [
                input_dims_mapping_candidates,
                [[[-2]]],
            ]
        elif (
            not input_dims_mapping_candidates
            and not output_dims_mapping_candidates
        ):
457 458 459
            inout_dims_mapping_generated = [[[[-2]]], [[[-2]]]]
        else:
            inout_dims_mapping_generated = [
460 461
                input_dims_mapping_candidates,
                output_dims_mapping_candidates,
462 463 464 465
            ]
        # Find valid dims_mapping generated for both inputs and outputs
        cached_dist_attr_candidates = []
        for inout_dims_mapping_list in itertools.product(
466 467
            *inout_dims_mapping_generated
        ):
468 469 470 471 472 473 474 475 476
            assert len(inout_dims_mapping_list) == 2
            if input_dims_mapping_candidates:
                assert len(inout_dims_mapping_list[0]) == len(input_names)
            if output_dims_mapping_candidates:
                assert len(inout_dims_mapping_list[1]) == len(output_names)
            # set the dims_mappings for inputs
            for i, dims_mapping in enumerate(inout_dims_mapping_list[0]):
                if dims_mapping != [-2]:
                    new_op_dist_attr.set_input_dims_mapping(
477 478
                        input_names[i], dims_mapping
                    )
479 480 481 482
            # set the dims_mappings for outputs
            for i, dims_mapping in enumerate(inout_dims_mapping_list[1]):
                if dims_mapping != [-2]:
                    new_op_dist_attr.set_output_dims_mapping(
483 484 485 486 487
                        output_names[i], dims_mapping
                    )
            new_dist_op = DistributedOperator(
                dist_op.serial_op, new_op_dist_attr
            )
488
            dist_op_impls = find_compatible_distributed_operator_impls(
489 490
                new_dist_op, partial=False
            )
491 492 493 494 495 496
            if dist_op_impls is None:
                continue
            for dist_op_impl in dist_op_impls:
                new_op_dist_attr.impl_type = dist_op_impl.type
                new_op_dist_attr.impl_idx = dist_op_impl.idx
                cached_dist_attr_candidates.append(
497 498
                    copy.deepcopy(new_op_dist_attr)
                )
499 500 501 502 503
        self._cached_candidates_info[key].append(cached_dist_attr_candidates)
        return self._cached_candidates_info[key][2]

    def construct_space(self):
        inter_node_partitions, intra_node_partitions = self._partition_devices(
504 505 506 507 508 509 510 511 512 513 514 515
            self._num_machines, self._num_devices_per_machine
        )
        self._space.choice(
            "inter_node_partitions",
            inter_node_partitions,
            default=inter_node_partitions[0],
        )
        self._space.choice(
            "intra_node_partitions",
            intra_node_partitions,
            default=intra_node_partitions[0],
        )
516 517 518 519 520 521 522 523 524 525 526 527 528 529 530

        dist_ops = self._dist_context._dist_ops_for_program
        for op_id, dist_op in dist_ops.items():
            op_type = dist_op.serial_op.type
            if self._include_op_types:
                if op_type in self._include_op_types:
                    self._concerned_dist_ops[op_id] = dist_op
            else:
                self._concerned_dist_ops[op_id] = dist_op

        for op_id, dist_op in self._concerned_dist_ops.items():
            op_type = dist_op.serial_op.type
            if op_type in self._exclude_op_types:
                del self._concerned_dist_ops[op_id]

531 532 533 534 535
        print(
            "Number of the concered dist ops",
            len(self._concerned_dist_ops),
            flush=True,
        )
536 537 538
        search_space = 1
        for op_id, dist_op in self._concerned_dist_ops.items():
            op_dist_attr_candidates = self._generate_dist_attr_candidates(
539 540
                op_id, dist_op
            )
541
            search_space *= len(op_dist_attr_candidates)
542 543 544 545 546
            self._space.choice(
                str(op_id),
                op_dist_attr_candidates,
                default=op_dist_attr_candidates[0],
            )

    def _compute_values_hash(self, values):
        keys = sorted(values.keys())
        s = "".join(str(k) + "=" + str(values[k]) for k in keys)
        return hashlib.sha256(s.encode("utf-8")).hexdigest()[:32]

    def _random_values(self):
        """Draw a value for every variable that was not tried before.

        Every variable of ``self._space`` is sampled with the current
        ``self._seed_state``, which is advanced by one per draw. A sample
        whose hash is already in ``self._tried_values`` is redrawn.

        Returns:
            The sampled values, or ``None`` once more than
            ``self._max_collisions`` already-tried samples were drawn.
        """
        sampled_space = TunableSpace()
        num_collisions = 0
        while True:
            for variable in self._space.variables.values():
                sampled_space._register(variable)
                sampled_space.values[variable.name] = variable.random(
                    self._seed_state
                )
                self._seed_state += 1
            sampled = sampled_space.values
            fingerprint = self._compute_values_hash(sampled)
            if fingerprint not in self._tried_values:
                self._tried_values.add(fingerprint)
                return sampled
            num_collisions += 1
            if num_collisions > self._max_collisions:
                return None

    def _populate_space(self):
        """Sample new values and wrap them with the matching trial status.

        Returns:
            A dict with keys ``"status"`` and ``"values"``. The status is
            ``TrialStatus.STOPPED`` (with ``None`` values) when no untried
            sample could be drawn, otherwise ``TrialStatus.RUNNING``.
        """
        sampled = self._random_values()
        if sampled is None:
            status = TrialStatus.STOPPED
        else:
            status = TrialStatus.RUNNING
        return {"status": status, "values": sampled}

    def _create_trial(self):
        """Create the next trial and advance the trial counter.

        The trial id is the current trial count, zero-padded to the number
        of characters of ``str(self._max_trials)``. When ``_max_trials`` is
        truthy and already reached, the trial is created as STOPPED with no
        values; otherwise its status and values come from
        ``_populate_space``.
        """
        id_width = len(str(self._max_trials))
        trial_id = f"{self._num_trials:0{id_width}d}"

        if not self._max_trials or self._num_trials < self._max_trials:
            outcome = self._populate_space()
            status, values = outcome["status"], outcome["values"]
        else:
            status, values = TrialStatus.STOPPED, None

        trial_space = TunableSpace()
        trial_space.variables = self._space.variables
        trial_space.values = values
        new_trial = Trial(
            tunable_space=trial_space, trial_id=trial_id, status=status
        )
        self._num_trials += 1
        return new_trial

    def _generate_pipeline_starts(self, process_mesh_list):
        total_ops = len(self._dist_context._dist_ops_for_program)
        total_stages = len(process_mesh_list)
        ops_per_stage = total_ops // total_stages
        if ops_per_stage == 0:
            return None
        # Compute the initial pipeline starts
        pipeline_starts = []
        start = 0
        pipeline_starts.append(0)
        # The pipeline_starts have total_stages+1 items, and
        # at least have 2 items.
        for _ in process_mesh_list:
            start += ops_per_stage
            pipeline_starts.append(start)
        pipeline_starts[-1] = total_ops
        # Adjust the pipeline starts by random selection
        directions = []
        sizes = []
        half_ops_per_stage = ops_per_stage // 2
        if half_ops_per_stage > 0 and total_stages > 1:
            new_pipeline_starts = []
            # Don't change the first start
            new_pipeline_starts.append(0)
            # Consider the starts except the first and the last one
            for _ in pipeline_starts[1:-1]:
                directions.append(Boolean("direction"))
                sizes.append(
625 626 627 628
                    IntRange(
                        "size", start=0, stop=half_ops_per_stage, endpoint=True
                    )
                )
629 630 631 632 633 634 635 636 637 638 639 640
            for i, start in enumerate(pipeline_starts[1:-1]):
                direction = directions[i].random(self._seed)
                size = sizes[i].random(self._seed)
                if direction:
                    # Substract 1 from size to avoid the overlapping of new starts
                    new_start = start - (size - 1)
                else:
                    new_start = start + size
                new_pipeline_starts.append(new_start)
            # Don't change the last start
            new_pipeline_starts.append(pipeline_starts[-1])
            # Validate the new starts
641 642 643 644 645 646 647
            print(
                "Adjusted pipeline starts",
                new_pipeline_starts,
                half_ops_per_stage,
                pipeline_starts,
                flush=True,
            )
648 649 650 651
            for i, new_start in enumerate(new_pipeline_starts[1:]):
                assert new_start > new_pipeline_starts[i]
            return new_pipeline_starts
        else:
652 653 654 655 656 657
            print(
                "Non-adjusted pipeline starts",
                pipeline_starts,
                half_ops_per_stage,
                flush=True,
            )
658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726
            return pipeline_starts

    def _apply_pipeline_partition(self, process_mesh_list):
        op_id_to_process_mesh = {}
        total_ops = len(self._dist_context._dist_ops_for_program)
        total_stages = len(process_mesh_list)
        ops_per_stage = total_ops // total_stages
        if ops_per_stage == 0:
            return None
        pipeline_starts = self._generate_pipeline_starts(process_mesh_list)
        start_idx = 1
        sorted_op_ids = sorted(self._dist_context._dist_ops_for_program.keys())
        for idx, op_id in enumerate(sorted_op_ids):
            if idx < pipeline_starts[start_idx]:
                op_id_to_process_mesh[op_id] = process_mesh_list[start_idx - 1]
            else:
                start_idx += 1
                op_id_to_process_mesh[op_id] = process_mesh_list[start_idx - 1]
        return op_id_to_process_mesh

    def _amend_dist_attr(self):
        # 1) Reshape the process mesh of [1, x] to [x] or [x, 1] to [x],
        # and amend the corresponding dims_mapping.
        # 2) Set the dim_mapping to -1 when the shape cannot be divided
        # by the corresponding processes.
        for dist_op in self._dist_context._dist_ops_for_program.values():
            dist_attr = dist_op.dist_attr
            process_mesh = dist_attr.process_mesh
            if process_mesh is None:
                continue
            assert process_mesh.ndim == 2
            dim_of_one = None
            dim_of_other = None
            if process_mesh.topology[0] == 1:
                dim_of_one = 0
                dim_of_other = 1
            elif process_mesh.topology[1] == 1:
                dim_of_one = 1
                dim_of_other = 0

            if dim_of_one is not None:
                dist_attr.process_mesh = ProcessMesh(process_mesh.processes)
                self._dist_context.add_process_mesh(dist_attr.process_mesh)

            for arg_name in dist_attr.inputs_dist_attrs.keys():
                new_dims_mapping = []
                dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                for dim_mapping in dims_mapping:
                    if dim_mapping == dim_of_one:
                        new_dims_mapping.append(-1)
                    elif dim_mapping == dim_of_other:
                        new_dims_mapping.append(0)
                    else:
                        new_dims_mapping.append(dim_mapping)
                dist_attr.set_input_dims_mapping(arg_name, new_dims_mapping)

                dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                # dynamic_dims = dist_attr.get_input_dynamic_dims(arg_name)
                process_mesh = dist_attr.process_mesh
                process_shape = process_mesh.topology
                tensor = dist_op.get_serial_input(arg_name)
                if dims_mapping:
                    tensor_shape = tensor.shape
                else:
                    continue
                for i, dim_mapping in enumerate(dims_mapping):
                    # if dim_mapping != -1 \
                    #     and (tensor_shape[i] % process_shape[dim_mapping] != 0 \
                    #     or dynamic_dims[i] == 1):
                    if dim_mapping != -1 and (
                        tensor_shape[i] % process_shape[dim_mapping] != 0
                    ):
                        dims_mapping[i] = -1
                    # it is a fix-bug
                    if dim_mapping != -1 and process_shape[dim_mapping] == 1:
                        dims_mapping[i] = -1

            for arg_name in dist_attr.outputs_dist_attrs.keys():
                new_dims_mapping = []
                dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                for dim_mapping in dims_mapping:
                    if dim_mapping == dim_of_one:
                        new_dims_mapping.append(-1)
                    elif dim_mapping == dim_of_other:
                        new_dims_mapping.append(0)
                    else:
                        new_dims_mapping.append(dim_mapping)
                dist_attr.set_output_dims_mapping(arg_name, new_dims_mapping)

                dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                # dynamic_dims = dist_attr.get_output_dynamic_dims(arg_name)
                process_mesh = dist_attr.process_mesh
                process_shape = process_mesh.topology

                tensor = dist_op.get_serial_output(arg_name)
                if dims_mapping:
                    tensor_shape = tensor.shape
                else:
                    continue
                for i, dim_mapping in enumerate(dims_mapping):
758 759 760
                    if dim_mapping != -1 and (
                        tensor_shape[i] % process_shape[dim_mapping] != 0
                    ):
761 762
                        dims_mapping[i] = -1
                    # it is a fix-bug
763
                    if dim_mapping != -1 and process_shape[dim_mapping] == 1:
764 765
                        dims_mapping[i] = -1
            dist_op_impls = find_compatible_distributed_operator_impls(
766 767
                dist_op, partial=False
            )
768 769 770
            serial_op_type = dist_op.serial_op.type

            if dist_op_impls is not None and (
771 772 773
                serial_op_type != "fused_softmax_mask_upper_triangle"
                or self._check_fused_softmax_mask_upper_triangle(dist_op)
            ):
774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792
                dist_op.dist_attr.impl_type = dist_op_impls[0].type
                dist_op.dist_attr.impl_idx = dist_op_impls[0].idx
            else:
                # Use the default dist op impl
                for arg_name in dist_attr.inputs_dist_attrs.keys():
                    dims_mapping = dist_attr.get_input_dims_mapping(arg_name)
                    for i, _ in enumerate(dims_mapping):
                        dims_mapping[i] = -1
                for arg_name in dist_attr.outputs_dist_attrs.keys():
                    dims_mapping = dist_attr.get_output_dims_mapping(arg_name)
                    for i, _ in enumerate(dims_mapping):
                        dims_mapping[i] = -1
                dist_op.dist_attr.impl_type = "default"
                dist_op.dist_attr.impl_idx = 0

    def _check_fused_softmax_mask_upper_triangle(self, dist_op):
        """The last_but_one dim shoule be equal to last dim."""
        input_name = dist_op.serial_op.input_arg_names[0]
        input_dims_mapping = dist_op.dist_attr.get_input_dims_mapping(
793 794
            input_name
        )
795 796
        topology = dist_op.dist_attr.process_mesh.topology
        input_tensor = dist_op.get_serial_input(input_name)
797 798 799 800 801 802 803 804 805 806
        last_but_one_dim = (
            input_tensor.shape[-2] // topology[input_dims_mapping[-2]]
            if input_dims_mapping[-2] != -1
            else input_tensor.shape[-2]
        )
        last_dim = (
            input_tensor.shape[-1] // topology[input_dims_mapping[-1]]
            if input_dims_mapping[-1] != -1
            else input_tensor.shape[-1]
        )
807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823
        if last_but_one_dim == last_dim:
            return True
        return False

    def _eval_trial(self, trial):
        if self._num_trials == 0:
            num_prev_trials = 0
        else:
            num_prev_trials = self._num_trials - 1

        results = None

        start_time = time.time()

        inter_node_partition = trial.space.values["inter_node_partitions"]
        intra_node_partition = trial.space.values["intra_node_partitions"]
        process_mesh_list = self._generate_process_mesh_list(
824 825
            inter_node_partition, intra_node_partition
        )
826 827
        print("\tprocess_mesh list", process_mesh_list, flush=True)
        op_id_to_process_mesh = self._apply_pipeline_partition(
828 829
            process_mesh_list
        )
830 831 832 833 834 835
        if op_id_to_process_mesh is None:
            print("Operators are less than pipeline stages", flush=True)
            return results

        op_id_to_dist_attr = {}
        for name, value in trial.space.values.items():
836 837 838 839
            if (
                name != "inter_node_partitions"
                and name != "intra_node_partitions"
            ):
840 841 842 843
                op_id_to_dist_attr[int(name)] = value

        end_time = time.time()
        cur_sample_time = end_time - start_time
844 845 846 847 848 849 850 851 852 853 854
        self._sample_time = (
            num_prev_trials * self._sample_time + cur_sample_time
        ) / self._num_trials
        print(
            "\tsample_time",
            num_prev_trials,
            self._num_trials,
            self._sample_time,
            cur_sample_time,
            flush=True,
        )
855 856 857 858 859 860 861

        assert len(op_id_to_process_mesh) == len(op_id_to_dist_attr)

        start_time = time.time()
        for op_id, process_mesh in op_id_to_process_mesh.items():
            dist_op = self._dist_context._dist_ops_for_program[op_id]
            dist_op.dist_attr = copy.deepcopy(op_id_to_dist_attr[op_id])
862 863 864 865 866 867 868
            assert (
                dist_op.dist_attr.impl_type
                == op_id_to_dist_attr[op_id].impl_type
            )
            assert (
                dist_op.dist_attr.impl_idx == op_id_to_dist_attr[op_id].impl_idx
            )
869 870 871 872 873 874
            dist_op.dist_attr.process_mesh = process_mesh
        self._amend_dist_attr()

        self._completer._complete_tensor_dist_attr_by_op()

        self._dist_context.block_state.parse_forward_blocks(
875 876
            self._dist_context.serial_main_program
        )
877 878 879

        end_time = time.time()
        cur_complete_time = end_time - start_time
880 881 882 883 884 885 886 887 888 889 890
        self._complete_time = (
            num_prev_trials * self._complete_time + cur_complete_time
        ) / self._num_trials
        print(
            "\tcomplete_time",
            num_prev_trials,
            self._num_trials,
            self._complete_time,
            cur_complete_time,
            flush=True,
        )
891 892 893 894 895

        start_time = time.time()
        estimate_time = self._estimate_trial()
        end_time = time.time()
        cur_estimate_time = end_time - start_time
896 897 898 899 900 901 902 903 904 905 906 907
        self._estimate_time = (
            num_prev_trials * self._estimate_time + cur_estimate_time
        ) / self._num_trials
        print(
            "\testimate_time",
            num_prev_trials,
            self._num_trials,
            self._estimate_time,
            cur_estimate_time,
            estimate_time,
            flush=True,
        )
908 909 910 911 912 913 914 915 916 917 918 919 920 921 922

        results = {"estimate_time": estimate_time}
        return results

    def _update_trail(self, trial, metrics, step=0):
        for metric_name, metric_value in metrics.items():
            trial.recorder.update(metric_name, metric_value, step=step)
        return trial.status

    def _estimate_trial(self):
        assert self._cluster is not None
        if self._mode == "eval":
            self._estimator = CostEstimator(
                self._dist_context.serial_main_program,
                self._cluster,
923 924
                loop_count=self._loop_count,
            )
925 926 927 928
        elif self._mode == "predict":
            self._estimator = CostEstimator(
                self._dist_context.serial_main_program,
                self._cluster,
929 930
                loop_count=self._loop_count,
            )
931 932 933 934 935 936 937 938 939
        elif self._mode == "train":
            # get serial main program with backward
            serial_main_program = self._dist_context.serial_main_program
            serial_startup_program = self._dist_context.serial_startup_program
            serial_optimizer = self._dist_context.serial_optimizer

            # Generate backward
            serial_loss = self._dist_context.serial_fetch_vars["loss"][0]
            params_grads = self._parallelizer._generate_backward(
940 941
                serial_main_program, serial_startup_program, serial_loss
            )
942 943 944

            # Generate optimizer
            optimizer_ops = self._parallelizer._generate_optimizer(
945 946 947 948 949 950 951 952
                serial_main_program,
                serial_startup_program,
                serial_optimizer,
                params_grads,
            )
            self._estimator = CostEstimator(
                serial_main_program, self._cluster, loop_count=self._loop_count
            )
953 954

        max_memory = self._estimator._estimate_max_memory_by_dist_op(
955 956
            self._dist_context
        )
957 958 959 960 961 962 963 964 965 966 967
        print("\tmax_memory", "{:,}".format(max_memory), flush=True)
        # The max memory must be less than 80% 32GB (hard code)
        if max_memory > 32 * 0.8 * 1024 * 1024 * 1024:
            return math.inf
        else:
            global_cost = self._estimator.estimate(self._dist_context)
            return global_cost.time

    def _store_init_parallel_strategy(self):
        # If there is no annotation information, use the dp as the initial parallel strategy.
        # TODO: we should need a better way to set up the initial parallel strategy.
968 969 970 971
        if (
            not self._dist_context.has_annotation
            or not self._dist_context.process_meshes
        ):
972 973 974 975
            ranks = self._num_machines * self._num_devices_per_machine
            tensor_node = self._dist_context._serial_ordered_tensor_nodes[0]
            tensor_node_id = _node_id(tensor_node)
            tensor = self._dist_context._dist_tensors_for_graph[
976 977
                tensor_node_id
            ].serial_tensor
978
            tensor_dist_attr = self._dist_context._dist_tensors_for_graph[
979 980
                tensor_node_id
            ].dist_attr
981 982
            tensor_dist_attr.process_mesh = ProcessMesh(list(range(ranks)))
            self._dist_context._process_meshes.append(
983 984
                tensor_dist_attr.process_mesh
            )
985 986 987 988 989 990 991 992 993 994
            tensor_dist_attr.dims_mapping = [0] + [
                -1 for _ in range(len(tensor.shape) - 1)
            ]
            tensor_dist_attr.mark_annotated("process_mesh")
            tensor_dist_attr.mark_annotated("dims_mapping")
            print("Use dp as the init parallel strategy!", flush=True)

        # Do the sharding propagation
        self._completer.complete_forward_annotation()
        self._dist_context.block_state.parse_forward_blocks(
995 996
            self._dist_context.serial_main_program
        )
997 998 999

        # Backup the intital parallel strategy
        self._init_parallel_strategy[0] = copy.deepcopy(
1000 1001
            self._dist_context._dist_tensors_for_program
        )
1002
        self._init_parallel_strategy[1] = copy.deepcopy(
1003 1004
            self._dist_context._dist_ops_for_program
        )
1005
        self._init_parallel_strategy[2] = copy.deepcopy(
1006 1007
            self._dist_context.process_meshes
        )
1008 1009 1010

        # Initialize the best parallel strategy to the initial one
        self._best_parallel_strategy[0] = copy.deepcopy(
1011 1012
            self._dist_context._dist_tensors_for_program
        )
1013
        self._best_parallel_strategy[1] = copy.deepcopy(
1014 1015
            self._dist_context._dist_ops_for_program
        )
1016
        self._best_parallel_strategy[2] = copy.deepcopy(
1017 1018
            self._dist_context._process_meshes
        )
1019 1020 1021 1022 1023 1024 1025 1026

    def _store_best_parallel_strategy(self):
        # Swap the best and the current parallel strategy
        tmp = [None, None, None]
        tmp[0] = self._best_parallel_strategy[0]
        tmp[1] = self._best_parallel_strategy[1]
        tmp[2] = self._best_parallel_strategy[2]
        self._best_parallel_strategy[
1027 1028
            0
        ] = self._dist_context._dist_tensors_for_program
1029
        self._best_parallel_strategy[
1030 1031
            1
        ] = self._dist_context._dist_ops_for_program
1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046
        self._best_parallel_strategy[2] = self._dist_context._process_meshes
        self._dist_context._dist_tensors_for_program = tmp[0]
        self._dist_context._dist_ops_for_program = tmp[1]
        self._dist_context._process_meshes = tmp[2]

    def tune(self):
        """Search the tunable space and install the best strategy found.

        The initial strategy is stored and estimated first and serves as the
        baseline. Trials are then created and evaluated until a created trial
        has the ``TrialStatus.STOPPED`` status; whenever a trial's estimated
        time is lower than the best one so far, its strategy is stored as the
        best. Finally the best dist tensors, dist ops and process meshes are
        assigned back to the distributed context.
        """
        global_start_time = time.time()
        self._dist_context._backup(serial=True, dist=True)
        # This store statement must follow the above backup statement
        self._store_init_parallel_strategy()
        init_time = self._estimate_trial()  # estimate_trial when init
        # We have to restore the distributed context, because the estimation
        # of one trial needs to generate the backward and update parts. Since
        # we will do the tuning process, here we only need to reset all
        # distributed information to the default one.
        self._dist_context._restore(
            serial=True,
            serial_mode="to_backup",
            dist=True,
            dist_mode="to_default",
        )

        best_time = init_time
        start_time = time.time()
        self.construct_space()
        end_time = time.time()
        print(
            "construct_space time",
            self._num_trials,
            end_time - start_time,
            flush=True,
        )
        create_trial_time = 0.0
        eval_trial_time = 0.0
        self._sample_time = 0.0
        self._complete_time = 0.0
        self._estimate_time = 0.0
        while True:
            start_time = time.time()
            trial = self._create_trial()
            if self._num_trials == 0:
                num_prev_trials = 0
            else:
                num_prev_trials = self._num_trials - 1
            # Divisor of the running means below. It equals self._num_trials
            # whenever that counter is >= 1, and is 1 (instead of 0) when the
            # counter is still 0, which would otherwise raise
            # ZeroDivisionError.
            num_trials = num_prev_trials + 1
            end_time = time.time()
            cur_create_trial_time = end_time - start_time
            create_trial_time = (
                num_prev_trials * create_trial_time + cur_create_trial_time
            ) / num_trials
            print(
                "create_trial time",
                num_prev_trials,
                self._num_trials,
                create_trial_time,
                cur_create_trial_time,
                flush=True,
            )
            if trial.status == TrialStatus.STOPPED:
                break
            # We need to backup the distributed context, because the
            # evaluation of one trial will generate the backward and update
            # parts which may change the context. However, the distributed
            # information of the context isn't backed up since a new one is
            # used.
            self._dist_context._backup(serial=True, dist=False)

            start_time = time.time()
            results = self._eval_trial(trial)
            end_time = time.time()
            cur_eval_trial_time = end_time - start_time
            eval_trial_time = (
                num_prev_trials * eval_trial_time + cur_eval_trial_time
            ) / num_trials
            print(
                "eval_trial time",
                num_prev_trials,
                self._num_trials,
                eval_trial_time,
                cur_eval_trial_time,
                "\n",
                flush=True,
            )

            # _eval_trial returns None when the trial could not be applied;
            # such a trial can never become the best one, so only the
            # context restoration below is performed for it.
            if results is not None:
                cur_time = results["estimate_time"]
                if cur_time < best_time:
                    self._update_trail(trial, results)
                    self._store_best_parallel_strategy()
                    best_time = cur_time
            # We need to restore the distributed context and reset the
            # distributed information to the default.
            self._dist_context._restore(
                serial=True,
                serial_mode="to_backup",
                dist=True,
                dist_mode="to_default",
            )
        # Select the best parallel strategy
        self._dist_context._dist_tensors_for_program = (
            self._best_parallel_strategy[0]
        )
        self._dist_context._dist_ops_for_program = self._best_parallel_strategy[
            1
        ]
        self._dist_context._process_meshes = self._best_parallel_strategy[2]