planner_v2.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import pickle
import sys

import numpy as np

from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.static.dist_attribute import (
    OperatorDistAttr,
    TensorDistAttr,
)
from paddle.distributed.auto_parallel.static.dist_op import DistributedOperator
from paddle.distributed.auto_parallel.static.dist_tensor import (
    DistributedTensor,
)

from ...utils.log_utils import get_logger
from .completion import Completer
from .dist_context import get_default_distributed_context
from .tuner.parallel_tuner import ParallelTuner
from .tuner.rule_based_tuner import RuleBasedTuner
from .utils import is_naive_data_parallel


class Planner:
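    """Plan the distributed attributes of the serial program before parallelization.

    Depending on ``strategy.auto_mode``, ``plan()`` loads a previously saved
    strategy when one is available, searches for one with a parallel tuner, or
    completes the user's semi-automatic annotations.
    """
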
    def __init__(self, mode, dist_context):
        self._mode = mode
        self._dist_context = dist_context
        self._load = False

        # NOTE: [HighOrderGrad]. There are grad ops in the forward phase, and they
        # need the dependency of backward-forward ops during forward completion.
        default_ctx = get_default_distributed_context()
        self._dist_context._dist_op_context = default_ctx.dist_op_context
        self._dist_context.data_parallel = default_ctx.data_parallel
        if not is_naive_data_parallel(self._dist_context):
            # Use the SSA graph for complex parallelism
            self._dist_context.initialize(with_graph=True)
        else:
            # Use the program for data parallelism
            self._dist_context.initialize(with_graph=False)

        self._completer = Completer(self._dist_context)

        self._strategy = dist_context.strategy
        # Set up the parallel tuner for automatic strategy search
        if self._strategy.auto_mode == "full_random":
            self._parallel_tuner = ParallelTuner(
                self._dist_context, mode=self._mode
            )
        elif self._strategy.auto_mode == "full_rule_based":
            self._parallel_tuner = RuleBasedTuner(
                self._dist_context, mode=self._mode
            )

    @property
    def completer(self):
        return self._completer

    def plan(self):
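        """Load a saved parallel strategy if possible, otherwise tune or complete one.

        The load path is taken from the ``tuner_load_path`` entry of the
        distributed context's JSON config.
        """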
        logger = get_logger(logging.INFO)
        path = None
        if self._dist_context._json_config:
            try:
                path = self._dist_context._json_config["tuner_load_path"]
            except Exception:
                path = None
        if path and os.path.exists(path):
            try:
                with open(path, "rb") as f:
                    dist_attrs = pickle.load(f)
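                # The pickled file stores the serialized tensor/op dist attrs, the
                # process meshes, and the cluster the strategy was saved on.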
                tensor_dist_attrs = dist_attrs["tensor"]
                op_dist_attrs = dist_attrs["op"]
                process_meshes = dist_attrs["process_meshes"]
                cluster = dist_attrs["cluster"]
                last_gpu_model = cluster.machines[0].devices[0].model
                last_gpu_memory = cluster.machines[0].devices[0].memory
                last_node_count = len(cluster.machines)
                last_device_count = len(cluster.get_all_devices("GPU"))

                gpu_model = (
                    self._dist_context.cluster.machines[0].devices[0].model
                )
                gpu_memory = (
                    self._dist_context.cluster.machines[0].devices[0].memory
                )
                node_count = len(self._dist_context.cluster.machines)
                device_count = len(
                    self._dist_context.cluster.get_all_devices("GPU")
                )
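                # The cached strategy is reused only if the cluster signature
                # (GPU model, GPU memory, node count, device count) is unchanged.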
                if (
                    gpu_model != last_gpu_model
                    or gpu_memory != last_gpu_memory
                    or last_node_count != node_count
                    or device_count != last_device_count
                ):
                    logger.info(
                        "The current cluster ({} nodes, {} {} devices) differs from the saved cluster ({} nodes, {} {} devices), so the planner will run again.".format(
                            node_count,
                            device_count,
                            gpu_model,
                            last_node_count,
                            last_device_count,
                            last_gpu_model,
                        )
                    )
                    need_set_dist_attr = False
                else:
                    need_set_dist_attr = True
            except Exception:
                need_set_dist_attr = False

            if need_set_dist_attr:
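                # Re-attach the saved dist attrs: rebuild each DistributedOperator
                # and DistributedTensor from its serialized attribute string.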
                for key in op_dist_attrs:
                    serial_op = self._dist_context._dist_ops_for_program[
                        key
                    ].serial_op
                    # clear dist attr
                    serial_op.dist_attr = OperatorDistAttr(serial_op.desc)
                    serial_op.dist_attr.parse_from_string(op_dist_attrs[key])
                    self._dist_context._dist_ops_for_program[
                        key
                    ] = DistributedOperator(serial_op)

                for key in tensor_dist_attrs:
                    serial_tensor = (
                        self._dist_context._dist_tensors_for_program[
                            key
                        ].serial_tensor
                    )
                    # clear dist attr
                    serial_tensor.dist_attr = TensorDistAttr(serial_tensor.desc)
                    serial_tensor.dist_attr.parse_from_string(
                        tensor_dist_attrs[key]
                    )
                    self._dist_context._dist_tensors_for_program[
                        key
                    ] = DistributedTensor(serial_tensor)

                process_meshes = []
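                # Rebuild ProcessMesh objects from the saved (process_ids, shape) pairs.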
                for item in dist_attrs["process_meshes"]:
                    process_ids = item[0]
                    shape = item[1]
                    process_meshes.append(
                        ProcessMesh(
                            np.array(process_ids).reshape(shape).tolist()
                        )
                    )

                self._dist_context.process_meshes = process_meshes
                self._load = True

                logger.info(
                    f"The parallel strategy has been loaded from {path}"
                )

        if not self._load:
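            # Search with the parallel tuner in the fully automatic modes, or just
            # complete the forward annotations in semi-automatic mode.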
            if self._strategy.auto_mode != "semi":
                self._parallel_tuner.tune()
            else:
                self._completer.complete_forward_annotation()

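        # Stop here when only the planning stage was requested via
        # PADDLE_AUTO_PARALLEL_STAGE (anything other than "run").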
        if os.getenv("PADDLE_AUTO_PARALLEL_STAGE", "run") != "run":
            sys.exit(0)

        # parse the forward sub-blocks of the serial main program
        self._dist_context.block_state.parse_forward_blocks(
            self._dist_context.serial_main_program
        )
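

# Usage sketch (illustrative only; the names and the driving code below are
# assumptions, not part of this module):
#
#     dist_context = ...  # a populated DistributedContext for the serial program
#     planner = Planner("train", dist_context)
#     planner.plan()                    # load, tune, or complete the strategy
#     completer = planner.completer     # exposed for later completion passes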