# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import copy
import paddle
import paddle.nn as nn
import paddle.static as static
import paddle.nn.functional as F
import paddle.utils as utils
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.auto_parallel.reshard import reshard
from paddle.distributed.auto_parallel.cost_model import estimate_cost
import paddle.fluid.core as core
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr

paddle.enable_static()
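# 8-rank "dp_mp_pp" hybrid-parallel setup: two pipeline stages, each with a
# 2x2 process mesh. pp_cfg lists the ranks belonging to each stage, and
# STAGE_0_CNT / STAGE_1_CNT are the op-index boundaries used below to split
# the profiled single-device op costs between the two stages.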
_global_parallel_strategy = "dp_mp_pp"
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]])
NUM_RANKS = 8
STAGE_0_CNT = 5
STAGE_1_CNT = 10
pp_cfg = [[0, 1, 4, 5], [2, 3, 6, 7]]

device = "gpu" if core.is_compiled_with_cuda() else "cpu"


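# Two-layer MLP with LayerNorm. When is_distributed=True, linear0.weight is
# sharded on PP_MESH_0 and linear1.weight on PP_MESH_1, placing the two
# matmuls on different pipeline stages with tensor-parallel weight sharding.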
class MLPLayer(nn.Layer):
    def __init__(self,
                 hidden_size=256,
                 intermediate_size=4 * 256,
                 initializer_range=0.02,
                 is_distributed=True):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(initializer=nn.initializer.Normal(
            mean=0.0, std=initializer_range))
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr)
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

        self.is_distributed = is_distributed

    def forward(self, input):
        if self.is_distributed:
            auto.shard_tensor(
                self.linear0.weight,
                dist_attr={"process_mesh": PP_MESH_0,
                           "dims_mapping": [-1, 1]})
            auto.shard_tensor(
                self.linear1.weight,
                dist_attr={"process_mesh": PP_MESH_1,
                           "dims_mapping": [1, -1]})

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        return out


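# Profile the single-device program with core.CostModel and collect per-op
# execution times, split into two dicts (op type -> time in ms) that serve as
# the standalone cost data for pipeline stage 0 and stage 1 respectively.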
def get_single_node_data():
    train_program = paddle.static.Program()
    startup_program = paddle.static.Program()

    loss, train_program, startup_program = mlp_forward(
        train_program, startup_program, is_distributed=False)

    cost_model = core.CostModel()
    cost_data = cost_model.profile_measure(train_program, startup_program,
                                           device, ["time"])

    op_name2cost = [{}, {}]
    for idx, op in enumerate(train_program.blocks[0].ops):
        if idx <= STAGE_0_CNT:
            op_name2cost[0][op.type] = cost_data.get_op_time_ms(idx)
        elif idx <= STAGE_1_CNT:
            op_name2cost[1][op.type] = cost_data.get_op_time_ms(idx)
    return op_name2cost


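# Build the MLP forward graph plus a mean-squared-error loss. In the
# distributed case the input/label are static.data tensors sharded onto the
# two pipeline meshes; in the single-device case they are constant
# paddle.ones tensors so the program can be profiled directly.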
def mlp_forward(train_program, start_program, is_distributed=True):
    with static.program_guard(train_program,
                              start_program), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 256
        sequence_len = 128
        if is_distributed:
            input = static.data(
                name="input", shape=[batch_size, hidden_size], dtype='float32')
            label = static.data(
                name="label", shape=[batch_size, 1], dtype='float32')
        else:
            input = paddle.ones(
                name="input", shape=[batch_size, hidden_size], dtype='float32')
            label = paddle.ones(
                name="label", shape=[batch_size, 1], dtype='float32')

        if is_distributed:
            auto.shard_tensor(
                input,
                dist_attr={"process_mesh": PP_MESH_0,
                           "dims_mapping": [0, -1]})
            auto.shard_tensor(
                label,
                dist_attr={"process_mesh": PP_MESH_1,
                           "dims_mapping": [0, -1]})

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            initializer_range=0.02,
            is_distributed=is_distributed)

        predict = mlp(input)
        error_cost = paddle.nn.functional.square_error_cost(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program


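# Produce the distributed main/startup programs for one rank: complete the
# forward annotations, generate the backward pass, partition the program
# logically for this rank, and apply the optimizer.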
def get_dist_prog(train_program, startup_program, dist_context, rank_id):
    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program)

    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None)

    # logical partition
    partitioner = Partitioner(dist_context, rank_id)
    auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition(
        complete_train_program, startup_program, params_grads)

    partitioned_optimize_ops = parallelizer._apply_optimize(
        auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads)

    return auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads


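# Estimation sanity checks: a real program should report a positive runtime
# and, for every rank, a positive static memory that does not exceed the peak
# memory; an empty program should report zero runtime and (near) zero memory.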
def check_runtime_estimation(cost):
    return cost.runtime > 0


def check_memory_estimation(cost):
    for i in range(NUM_RANKS):
        if cost.static_mem[i] <= 0 or cost.peak_mem[i] <= 0:
            return False
        if cost.static_mem[i] > cost.peak_mem[i]:
            return False
    return True


def check_empty_program_runtime(cost):
    return cost.runtime == 0


def check_empty_program_memory(cost):
    for mem in cost.peak_mem:
        if mem > 1:
            return False
    for mem in cost.static_mem:
        if mem > 1:
            return False
    return True


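# Two cases: estimating an empty program should give zero cost, while
# estimating the 8-rank distributed MLP programs (after partitioning and
# resharding) should give positive runtime and memory estimates.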
class TestCostModel(unittest.TestCase):
    def test_empty_program_cost_model(self):
        empty_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        standalone_cost_data = [{}]
        empty_pp_cfg = None
        cluster = None
        cost = estimate_cost(
            [empty_program],
            cluster=cluster,
            pipeline_config=empty_pp_cfg,
            standalone_cost_data=standalone_cost_data,
            batch_size=1)

        self.assertTrue(check_empty_program_runtime(cost))
        self.assertTrue(check_empty_program_memory(cost))

    def test_auto_parallel_cost_model(self):
        standalone_cost_data = get_single_node_data()
        dist_program = []
        for rank_id in range(NUM_RANKS):
            train_program = paddle.static.Program()
            startup_program = paddle.static.Program()
            dist_context = DistributedContext()
            distributed_program, dist_startup_prog, dist_params_grads = get_dist_prog(
                train_program, startup_program, dist_context, rank_id)
            reshard(distributed_program, dist_startup_prog, rank_id,
                    dist_context, dist_params_grads)
            dist_program.append(distributed_program)
        cluster = None
        cost = estimate_cost(
            dist_program,
            cluster=cluster,
            pipeline_config=pp_cfg,
            standalone_cost_data=standalone_cost_data,
            batch_size=4)
        self.assertTrue(check_runtime_estimation(cost))
        self.assertTrue(check_memory_estimation(cost))


if __name__ == "__main__":
    unittest.main()