# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import tempfile
import unittest

from test_cluster import cluster_json

import paddle
import paddle.nn.functional as F
from paddle import nn, static, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.cluster import Cluster
from paddle.distributed.auto_parallel.completion import Completer
from paddle.distributed.auto_parallel.cost import (
    AllreduceSumOpCost,
    _g_op_cost_factory,
)
from paddle.distributed.auto_parallel.cost.base_cost import (
    build_comm_costs_from_descs,
    build_comm_desc_from_dist_op,
    build_comp_costs_from_descs,
    build_comp_desc_from_dist_op,
    build_dp_costs,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer
from paddle.distributed.fleet import auto

paddle.enable_static()
_global_parallel_strategy = "dp_mp_pp"
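# 2 x 2 x 2 process mesh over ranks 0-7 for the dp_mp_pp strategy; PP_MESH_0
# and PP_MESH_1 below are the two pipeline-stage sub-meshes.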
_global_process_mesh = auto.ProcessMesh(
    [[[0, 1], [4, 5]], [[2, 3], [6, 7]]], dim_names=["x", "y", "z"]
)
PP_MESH_0 = auto.ProcessMesh([[0, 1], [4, 5]], dim_names=["x", "y"])
PP_MESH_1 = auto.ProcessMesh([[2, 3], [6, 7]], dim_names=["x", "y"])


class MLPLayer(nn.Layer):
    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        initializer_range=0.02,
    ):
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        weight_attr = paddle.ParamAttr(
            initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
        )
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, input):
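        # linear0 lives on the stage-0 mesh with its weight split along "y",
        # linear1 on the stage-1 mesh, combining tensor and pipeline parallelism.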
        auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, "y"])
        auto.shard_tensor(self.linear1.weight, PP_MESH_1, ["y", None])

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        return out


def mlp_forward(train_program, start_program):
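    # Build the serial program: data placeholders, a sparse embedding lookup,
    # the sharded MLP and a mean-squared-error loss.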
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 1024
        sequence_len = 512
        input = static.data(
            name="input", shape=[batch_size, hidden_size], dtype='float32'
        )
        label = static.data(
            name="label", shape=[batch_size, 1], dtype='float32'
        )

        fill_constant_out = paddle.fluid.layers.fill_constant_batch_size_like(
            input=input, shape=[batch_size], value=1, dtype="int32"
        )
        embedding = paddle.nn.Embedding(10, hidden_size, sparse=True)
        embedding_out = embedding(fill_constant_out)

        auto.shard_tensor(input, PP_MESH_0, ["x", None])
        auto.shard_tensor(label, PP_MESH_1, ["x", None])

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            initializer_range=0.02,
        )

        predict = mlp(embedding_out)
        error_cost = paddle.nn.functional.square_error_cost(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program


def get_prog(train_program, startup_program, dist_context, rank_id):
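    # Annotate the serial program with the global process mesh, complete the
    # forward distributed attributes, and generate backward to obtain params_grads.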
    global _global_process_mesh
    dist_context.process_mesh = _global_process_mesh
    loss, train_program, startup_program = mlp_forward(
        train_program, startup_program
    )

    fleet._user_defined_strategy = fleet.DistributedStrategy()
    fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer()
    parallelizer = AutoParallelizer(fleet)
    parallelizer._dist_context = dist_context

    # serial forward & backward completion
    completer = Completer(dist_context)
    complete_train_program = completer.complete_forward_annotation(
        train_program
    )
    dist_context.block_state.parse_forward_blocks(complete_train_program)
    params_grads = parallelizer._generate_backward(
        complete_train_program,
        startup_program,
        loss,
        parameter_list=None,
        no_grad_set=None,
        callbacks=None,
    )
    return train_program, startup_program, params_grads


class TestBaseCost(unittest.TestCase):
    def setUp(self):
        self.temp_dir = tempfile.TemporaryDirectory()

    def tearDown(self):
        self.temp_dir.cleanup()

    def test_base_cost(self):
        # Build cluster
        cluster_json_path = os.path.join(
            self.temp_dir.name, "auto_parallel_cluster.json"
        )
        cluster_json_object = json.loads(cluster_json)
        with open(cluster_json_path, "w") as cluster_json_file:
            json.dump(cluster_json_object, cluster_json_file)
        cluster = Cluster()
        cluster.build_from_file(cluster_json_path)

        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        dist_context = DistributedContext()
        rank_id = 2
        train_program, startup_program, params_grads = get_prog(
            train_program, startup_program, dist_context, rank_id
        )

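        # For every op that has a distributed counterpart, build computation and
        # communication descriptors and check that op costs can be built from them.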
        for op in train_program.global_block().ops:
            dist_op = dist_context.get_dist_op_for_program(op)
            if dist_op:
                processes = dist_op.dist_attr.process_mesh.process_ids
                comp_descs = build_comp_desc_from_dist_op(dist_op, dist_context)
                self.assertTrue(isinstance(comp_descs, dict) and comp_descs)
                var_names = None
                if op.input_arg_names:
                    var_names = op.input_arg_names[0]
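                    # Build allreduce descriptors both from a parallel axis and
                    # from explicit group ranks.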
                    comm_descs = build_comm_desc_from_dist_op(
                        "c_allreduce_sum",
                        dist_op,
                        dist_context,
                        var_names,
                        attrs=None,
                        parallel_axis=0,
                        group_ranks=None,
                    )
                    self.assertTrue(isinstance(comm_descs, dict) and comm_descs)
                    comm_descs = build_comm_desc_from_dist_op(
                        "c_allreduce_sum",
                        dist_op,
                        dist_context,
                        var_names,
                        attrs=None,
                        parallel_axis=None,
                        group_ranks=processes,
                    )
                    self.assertTrue(isinstance(comm_descs, dict) and comm_descs)

                    comm_costs = build_comm_costs_from_descs(
                        AllreduceSumOpCost,
                        dist_context,
                        processes,
                        comm_descs,
                        cluster,
                    )
                    self.assertTrue(comm_costs)

                    comp_costs = build_comp_costs_from_descs(
                        _g_op_cost_factory[op.type],
                        dist_context,
                        processes,
                        comp_descs,
                        cluster,
                    )
                    self.assertTrue(comp_costs)

                    result = []
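                    # Collect the data-parallel (gradient synchronization) costs
                    # for this var into result.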
                    build_dp_costs(
                        result,
                        dist_op,
                        dist_context,
                        var_names[0],
                        None,
                        0,
                        cluster,
                    )
                    self.assertTrue(result)

        # Remove unnecessary files
        if os.path.exists(cluster_json_path):
            os.remove(cluster_json_path)


if __name__ == "__main__":
    unittest.main()