# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
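
"""Tests for DistributedContext: backing up and restoring serial programs and
distributed attributes, and the attribute-sharing semantics of deepcopy."""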

import copy
import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle import nn, static
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.distributed.fleet import auto

paddle.enable_static()

batch_size = 4
hidden_size = 1024
sequence_len = 512
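
# Two disjoint 1-D process meshes over devices {0, 1} and {2, 3}; "x" names
# the single mesh dimension that tensors below are sharded along.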
_g_process_mesh = [
    auto.ProcessMesh([0, 1], dim_names=["x"]),
    auto.ProcessMesh([2, 3], dim_names=["x"]),
]


def get_random_inputs_and_labels(input_shape, label_shape):
    input = np.random.random(size=input_shape).astype('float32')
    label = np.random.random(size=label_shape).astype('float32')
    return input, label


def batch_generator_creator():
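    # Builds a reader that yields `batch_size` random (input, label) batches
    # whose shapes match the static.data declarations in get_program().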
    def __reader__():
        for _ in range(batch_size):
            batch_input, batch_label = get_random_inputs_and_labels(
                [batch_size, sequence_len, hidden_size],
                [batch_size, sequence_len, 1],
            )
            yield batch_input, batch_label

    return __reader__


class MLPLayer(nn.Layer):
    def __init__(
        self,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        dropout_ratio=0.1,
        initializer_range=0.02,
    ):
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        param_initializer = nn.initializer.Normal(
            mean=0.0, std=initializer_range
        )

        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
        self.linear0 = nn.Linear(
            d_model,
            dim_feedforward,
            weight_attr=paddle.ParamAttr(initializer=param_initializer),
            bias_attr=None,
        )
        self.linear1 = nn.Linear(
            dim_feedforward,
            d_model,
            weight_attr=paddle.ParamAttr(initializer=param_initializer),
            bias_attr=None,
        )

    def forward(self, input):
        out = self.norm(input)
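        # Shard linear0's weight along its output dimension ("x" on mesh 0),
        # i.e. a column-parallel split.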
        auto.shard_tensor(self.linear0.weight, _g_process_mesh[0], [None, "x"])
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
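        # Shard linear1's weight along its input dimension ("x" on mesh 1),
        # i.e. a row-parallel split on the second mesh.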
        auto.shard_tensor(self.linear1.weight, _g_process_mesh[1], ["x", None])
        out = self.linear1(out)

        return out


def get_program():
    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    # fleet.init(is_collective=True, strategy=dist_strategy)

    train_program = static.Program()
    start_program = static.Program()
    with static.program_guard(train_program, start_program):
        # input
        input = static.data(
            name="input",
            shape=[batch_size, sequence_len, hidden_size],
            dtype='float32',
        )
        label = static.data(
            name="label", shape=[batch_size, sequence_len, 1], dtype='float32'
        )
        data_holder = [input, label]
        # dataloader
        dataloader = paddle.io.DataLoader.from_generator(
            feed_list=data_holder, capacity=4 * batch_size, iterable=False
        )
        dataloader.set_batch_generator(
            batch_generator_creator(), places=paddle.static.cuda_places()
        )
        # data dist_attr
        auto.shard_tensor(input, _g_process_mesh[0], ["x", None, None])
        auto.shard_tensor(label, _g_process_mesh[0], ["x", None, None])

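        # Stack three MLP blocks; each one shards linear0 on mesh {0, 1} and
        # linear1 on mesh {2, 3} via the shard_tensor calls in forward().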
        mlp_start = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        pred = mlp_start(input)

        mlp_mid = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        pred = mlp_mid(pred)

        mlp_end = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            dropout_ratio=0.1,
            initializer_range=0.02,
        )
        pred = mlp_end(pred)

        error_cost = paddle.nn.functional.square_error_cost(pred, label)
        loss = paddle.mean(error_cost)

        optimizer = paddle.optimizer.Adam(
            learning_rate=0.00001,
            beta1=0.9,
            beta2=0.999,
            epsilon=1e-08,
            grad_clip=None,
        )

        feed_vars = {"inputs": [input], "labels": [label]}
        fetch_vars = {"loss": [loss]}

    return (
        train_program,
        start_program,
        dataloader,
        loss,
        optimizer,
        feed_vars,
        fetch_vars,
    )


class TestDistributedContext(unittest.TestCase):
    def test_backup_restore(self):
        (
            train_program,
            start_program,
            dataloader,
            loss,
            optimizer,
            feed_vars,
            fetch_vars,
        ) = get_program()
        dist_context = DistributedContext(
            train_program, start_program, optimizer, loss, feed_vars, fetch_vars
        )
        dist_context.initialize()

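        # Exercise the backup/restore cycle under each restore mode:
        # "to_backup", "to_original", "to_default", and "to_nothing".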
        dist_context._backup(serial=True, dist=True)
        dist_context._restore(
            serial=True,
            serial_mode="to_backup",
            dist=True,
            dist_mode="to_backup",
        )

        dist_context._backup(serial=True, dist=True)
        dist_context._restore(
            serial=True,
            serial_mode="to_original",
            dist=True,
            dist_mode="to_original",
        )

        dist_context._backup(serial=True, dist=True)
        dist_context._restore(serial=True, dist=True, dist_mode="to_default")

        dist_context._backup(serial=True, dist=True)
        dist_context._restore(serial=True, dist=True, dist_mode="to_nothing")

    def test_deepcopy(self):
        (
            train_program,
            start_program,
            dataloader,
            loss,
            optimizer,
            feed_vars,
            fetch_vars,
        ) = get_program()
        dist_context = DistributedContext(
            train_program, start_program, optimizer, loss, feed_vars, fetch_vars
        )
        dist_context.initialize()

        copy_dist_context = copy.deepcopy(dist_context)

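        # Attributes that DistributedContext is expected to share (same
        # object identity) with its deepcopy rather than duplicate.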
        copy_list = [
            "_original_serial_main_program",
            "_original_serial_startup_program",
            "_serial_main_program",
            "_serial_startup_program",
            "_serial_graph",
            "_dist_main_programs",
            "_dist_startup_programs",
            "_serial_ordered_nodes",
            "_serial_ordered_tensor_nodes",
            "_serial_ordered_op_nodes",
            "_original_serial_loss",
            "_original_serial_feed_vars",
            "_original_serial_fetch_vars",
            "_serial_loss",
            "_serial_feed_vars",
            "_serial_fetch_vars",
            "_serial_optimizer",
            "_backup_serial_main_program_stack",
            "_backup_serial_startup_program_stack",
            "_pass_context",
            "_tensor_nodes_with_same_name",
        ]

        for attr_name in copy_list:
            assert id(getattr(copy_dist_context, attr_name)) == id(
                getattr(dist_context, attr_name)
            )


if __name__ == "__main__":
    unittest.main()