# auto_parallel_save_load.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import shutil
import unittest

import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.static as static
import paddle.utils as utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.utils import (
    load_checkpoint_into_program,
    save_distributed_checkpoint,
)
from paddle.distributed.fleet import auto
from paddle.fluid.initializer import NumpyArrayInitializer

# Every program in this module is built in static-graph mode.
paddle.enable_static()

# Sharding configuration read by MLPLayer.forward and mlp_forward. Each test
# method assigns these before building its program.
_global_parallel_strategy = _global_process_mesh = None
PP_MESH_0 = PP_MESH_1 = None


class MLPLayer(nn.Layer):
    """Feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

    Both linear weights are filled from one seeded normal draw, so every run
    starts from identical parameters. ``forward`` attaches sharding
    annotations to the weights according to the module-level
    ``_global_parallel_strategy``.

    Args:
        hidden_size: model width (input and output feature size).
        intermediate_size: width of the hidden projection.
        initializer_range: standard deviation of the normal weight init.
    """

    def __init__(
        self, hidden_size=64, intermediate_size=4 * 64, initializer_range=0.02
    ):
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        # Note: this reseeds the *global* numpy RNG, which also fixes the
        # state seen by any np.random draws made after the layer is built.
        np.random.seed(2021)
        arr = np.random.normal(
            0, initializer_range, size=(d_model, dim_feedforward)
        )
        # linear0.weight is [d_model, dim_feedforward] while linear1.weight is
        # [dim_feedforward, d_model]. Give each layer an initializer of its
        # own shape. linear1 reuses exactly the same values in the same
        # row-major order (reshape, not transpose).
        weight_attr0 = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr))
        weight_attr1 = paddle.ParamAttr(
            initializer=NumpyArrayInitializer(
                arr.reshape(dim_feedforward, d_model)
            )
        )
        bias_attr = None

        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr0, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr1, bias_attr=bias_attr
        )
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, input):
        """Annotate weight placement for the active strategy, then run the MLP."""
        if _global_parallel_strategy == "pp":
            # Pipeline: linear0 lives on PP_MESH_0 and linear1 on PP_MESH_1,
            # both unsharded within their mesh.
            auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
            auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
        elif _global_parallel_strategy == "mp":
            # Model parallel: split linear0 along its output dim and linear1
            # along its input dim over mesh axis "x".
            auto.shard_tensor(
                self.linear0.weight, _global_process_mesh, [None, "x"]
            )
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, ["x", None]
            )
        elif _global_parallel_strategy == "dp":
            # Data parallel: weights are fully replicated.
            auto.shard_tensor(
                self.linear0.weight, _global_process_mesh, [None, None]
            )
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, [None, None]
            )

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        return out


def mlp_forward(train_program, start_program):
    """Build the MLP regression graph into the given programs.

    Declares the ``input`` [4, 64] and ``label`` [4, 1] feed variables and
    annotates their placement for the active ``_global_parallel_strategy``.
    It then appends the MLPLayer forward pass and a mean squared-error loss.

    Args:
        train_program: program receiving the forward ops.
        start_program: startup program receiving parameter initializers.

    Returns:
        Tuple ``(loss, train_program, start_program)``.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 64
        # Local name avoids shadowing the builtin; the feed name is "input".
        input_ = static.data(
            name="input", shape=[batch_size, hidden_size], dtype='float32'
        )
        label = static.data(
            name="label", shape=[batch_size, 1], dtype='float32'
        )

        if _global_parallel_strategy == "pp":
            # Features enter on the first stage; labels live on the second.
            auto.shard_tensor(input_, PP_MESH_0, [None, None])
            auto.shard_tensor(label, PP_MESH_1, [None, None])
        elif _global_parallel_strategy == "dp":
            # Split the batch dimension across mesh axis "x".
            auto.shard_tensor(input_, _global_process_mesh, ["x", None])
        elif _global_parallel_strategy == "mp":
            # Every rank sees the full batch.
            auto.shard_tensor(input_, _global_process_mesh, [None, None])

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            initializer_range=0.02,
        )

        predict = mlp(input_)
        error_cost = paddle.nn.functional.square_error_cost(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program


def get_distributed_program():
    """Build the MLP training program and partition it with semi-auto fleet.

    Initializes fleet in collective semi-auto mode. It then builds the serial
    MLP program and lets the fleet-wrapped SGD optimizer produce this rank's
    distributed programs.

    Returns:
        Tuple ``(dist_main_prog, dist_startup_prog, loss)``.
    """
    train_program = static.Program()
    startup_program = static.Program()

    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=dist_strategy)

    loss, train_program, startup_program = mlp_forward(
        train_program, startup_program
    )

    optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer)
    # In semi-auto mode minimize() returns four values; only the two
    # distributed programs are needed here.
    _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
        loss, startup_program
    )

    return dist_main_prog, dist_startup_prog, loss


class TestMLPSaveLoad(unittest.TestCase):
    """Round-trip a distributed checkpoint under dp, mp and pp strategies.

    Each test trains for 20 steps and saves a distributed checkpoint right
    before step 10. It then reloads that checkpoint and replays steps 10-19.
    The loss from the replay must exactly match the loss from the original
    run. All meshes used here span ranks 0 and 1, so the tests need a
    two-rank GPU launch.
    """

    # Ranks covered by every process mesh in this class; used to list the
    # per-rank checkpoint files that are loaded back.
    _NUM_RANKS = 2

    def setUp(self):
        paddle.seed(2021)
        random.seed(2021)
        np.random.seed(2021)

    def _save_load_roundtrip(self, tag, fetch_ranks=None):
        """Train, checkpoint at step 10, reload, replay, and compare losses.

        Args:
            tag: strategy tag ("dp", "mp", "pp"). Rank ``r`` writes its
                checkpoint to ``./output_<tag><r>``.
            fetch_ranks: ranks that fetch ``loss`` and check the result;
                ``None`` means every rank. Other ranks run the program
                without a fetch list.
        """
        dist_main_prog, dist_start_prog, loss = get_distributed_program()
        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        rank = paddle.distributed.get_rank()
        fetch = fetch_ranks is None or rank in fetch_ranks

        inputs = np.random.random(size=(80, 64)).astype('float32')
        labels = np.random.random(size=(80, 1)).astype('float32')

        def run_step(step):
            # Run one step on batch `step` (4 rows per step).
            feed = {
                "input": inputs[step * 4 : (step + 1) * 4, :],
                "label": labels[step * 4 : (step + 1) * 4, :],
            }
            return exe.run(
                dist_main_prog,
                feed=feed,
                fetch_list=[loss] if fetch else None,
            )

        path = "./output_{}{}".format(tag, rank)
        for step in range(20):
            if step == 10:
                os.makedirs(path, exist_ok=True)
                # Remove this rank's checkpoint even if the test fails.
                self.addCleanup(shutil.rmtree, path, ignore_errors=True)
                save_distributed_checkpoint(dist_main_prog, path, path)
            res = run_step(step)

        last_res = res[0] if fetch else None

        ranks = range(self._NUM_RANKS)
        ckpt_path = [
            "./output_{}{}/model_state_rank{}.pdmodel".format(tag, r, r)
            for r in ranks
        ]
        dist_attr_path = [
            "./output_{}{}/dist_attr_rank{}.pdattr".format(tag, r, r)
            for r in ranks
        ]
        load_checkpoint_into_program(ckpt_path, dist_attr_path, dist_main_prog)

        for step in range(10, 20):
            res = run_step(step)

        if fetch:
            np.testing.assert_array_equal(last_res, res[0])

    def test_mlp_dp(self):
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "dp"
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        self._save_load_roundtrip("dp")

    def test_mlp_mp(self):
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "mp"
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        self._save_load_roundtrip("mp")

    def test_mlp_pp(self):
        global _global_parallel_strategy, _global_process_mesh
        global PP_MESH_0, PP_MESH_1
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
        PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
        # Under pp the label and linear1 are placed on PP_MESH_1, so only
        # rank 1 fetches and checks the loss.
        self._save_load_roundtrip("pp", fetch_ranks=(1,))


if __name__ == "__main__":
    # Discover and run the TestCase classes defined in this module.
    unittest.main(module="__main__")