# auto_parallel_save_load.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import shutil
import unittest

import numpy as np

import paddle
import paddle.nn.functional as F
from paddle import nn, static, utils
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.utils import (
    load_checkpoint_into_program,
    save_distributed_checkpoint,
)
from paddle.distributed.fleet import auto

# Every program in this file is built as a static graph.
paddle.enable_static()

# Parallel configuration read by MLPLayer.forward and mlp_forward; each test
# method assigns these before it builds its program.
# Strategy name ("dp", "mp" or "pp"); None until a test sets it.
_global_parallel_strategy = None
# Mesh spanning all ranks, used by the "dp" and "mp" annotations.
_global_process_mesh = None
# Single-rank meshes for the two pipeline stages (only used by "pp").
PP_MESH_0 = PP_MESH_1 = None


class MLPLayer(nn.Layer):
    """Feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

    Args:
        hidden_size: width of the input and output features.
        intermediate_size: width of the hidden projection.
        initializer_range: standard deviation of the normal distribution the
            initial weights are drawn from.
    """

    def __init__(
        self, hidden_size=64, intermediate_size=4 * 64, initializer_range=0.02
    ):
        super().__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        # Fixed seed so the initial weights are reproducible. NOTE: this
        # re-seeds NumPy's global RNG, which also affects any np.random calls
        # the caller makes afterwards.
        np.random.seed(2021)
        arr = np.random.normal(
            0, initializer_range, size=(d_model, dim_feedforward)
        )
        weight_attr = paddle.ParamAttr(
            initializer=paddle.nn.initializer.Assign(arr)
        )
        bias_attr = None

        # NOTE(review): both layers reuse one Assign initializer built from a
        # (d_model, dim_feedforward) array, while linear1's weight is declared
        # as (dim_feedforward, d_model). Kept as-is; confirm this is intended.
        self.linear0 = nn.Linear(
            d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
        )
        self.linear1 = nn.Linear(
            dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
        )
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, input):
        """Annotate weight sharding for the active strategy, then run the MLP.

        The annotation is selected by the module-level
        ``_global_parallel_strategy``; any other value leaves the weights
        unannotated.
        """
        if _global_parallel_strategy == "pp":
            # Pipeline: each linear layer lives on its own stage mesh.
            auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
            auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
        elif _global_parallel_strategy == "mp":
            # Tensor parallel: shard linear0's second weight dim and
            # linear1's first weight dim along mesh dim "x".
            auto.shard_tensor(
                self.linear0.weight, _global_process_mesh, [None, "x"]
            )
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, ["x", None]
            )
        elif _global_parallel_strategy == "dp":
            # Data parallel: weights are not split along any mesh dim.
            auto.shard_tensor(
                self.linear0.weight, _global_process_mesh, [None, None]
            )
            auto.shard_tensor(
                self.linear1.weight, _global_process_mesh, [None, None]
            )

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        return out


def mlp_forward(train_program, start_program):
    """Build the MLP regression graph into the given static programs.

    Declares the ``input`` and ``label`` feed variables and annotates their
    sharding for the active ``_global_parallel_strategy``. It then appends the
    MLP followed by a mean of squared errors as the loss.

    Args:
        train_program: main program that receives the forward ops.
        start_program: startup program passed to ``static.program_guard``.

    Returns:
        tuple: ``(loss, train_program, start_program)``.
    """
    with static.program_guard(
        train_program, start_program
    ), utils.unique_name.guard():
        batch_size = 4
        hidden_size = 64
        inputs = static.data(
            name="input", shape=[batch_size, hidden_size], dtype='float32'
        )
        label = static.data(
            name="label", shape=[batch_size, 1], dtype='float32'
        )

        if _global_parallel_strategy == "pp":
            # The input feeds the first stage; the label is consumed by the
            # loss on the second stage.
            auto.shard_tensor(inputs, PP_MESH_0, [None, None])
            auto.shard_tensor(label, PP_MESH_1, [None, None])
        elif _global_parallel_strategy == "dp":
            # Split the batch dimension along mesh dim "x".
            auto.shard_tensor(inputs, _global_process_mesh, ["x", None])
        elif _global_parallel_strategy == "mp":
            # Input is not split; the model split is on the weights.
            auto.shard_tensor(inputs, _global_process_mesh, [None, None])

        mlp = MLPLayer(
            hidden_size=hidden_size,
            intermediate_size=4 * hidden_size,
            initializer_range=0.02,
        )

        predict = mlp(inputs)
        error_cost = F.square_error_cost(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program


def get_distributed_program():
    """Build the MLP and partition it with fleet's semi-auto parallelizer.

    Initializes fleet in collective, semi-auto mode, builds the serial program
    with ``mlp_forward`` and calls the distributed optimizer's ``minimize``.

    Returns:
        tuple: ``(dist_main_prog, dist_startup_prog, loss)`` where the two
        programs are the last two values returned by ``minimize``.
    """
    train_program = static.Program()
    startup_program = static.Program()

    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=dist_strategy)

    loss, train_program, startup_program = mlp_forward(
        train_program, startup_program
    )

    # NOTE(review): paddle.fluid is deprecated in newer Paddle releases;
    # paddle.optimizer.SGD is the presumptive replacement. Confirm it works
    # with the semi-auto parallelizer before switching.
    optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer)
    # minimize returns four values; only the distributed startup and main
    # programs (the last two) are needed here.
    _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
        loss, startup_program
    )

    return dist_main_prog, dist_startup_prog, loss


class TestMLPSaveLoad(unittest.TestCase):
    """Round-trip tests for auto-parallel distributed checkpoints.

    Each test does the following:

    * trains the MLP for ``_NUM_STEPS`` steps, writing a per-rank checkpoint
      just before step ``_SAVE_STEP``;
    * loads every rank's checkpoint back into the same program;
    * replays the remaining steps;
    * checks that the final loss equals the loss of the uninterrupted run.

    The checkpoint file lists assume a launch with exactly two ranks (0, 1).
    """

    # Number of steps in the uninterrupted training run.
    _NUM_STEPS = 20
    # The checkpoint is written before this step runs; the replay starts here.
    _SAVE_STEP = 10
    # Rows of the random dataset consumed per step.
    _BATCH_SIZE = 4
    # Number of ranks whose checkpoints are loaded back.
    _NUM_RANKS = 2

    def setUp(self):
        paddle.seed(2021)
        random.seed(2021)
        np.random.seed(2021)

    def _run_step(self, exe, program, inputs, labels, step, fetch_list):
        """Run one step on mini-batch number ``step``.

        ``fetch_list=None`` runs the step without fetching anything.
        """
        begin = step * self._BATCH_SIZE
        end = begin + self._BATCH_SIZE
        feed = {"input": inputs[begin:end, :], "label": labels[begin:end, :]}
        return exe.run(program, feed=feed, fetch_list=fetch_list)

    @classmethod
    def _checkpoint_files(cls, tag):
        """Return ``(ckpt_path, dist_attr_path)`` lists covering every rank."""
        ranks = range(cls._NUM_RANKS)
        ckpt_path = [
            "./output_{0}{1}/model_state_rank{1}.pdmodel".format(tag, r)
            for r in ranks
        ]
        dist_attr_path = [
            "./output_{0}{1}/dist_attr_rank{1}.pdattr".format(tag, r)
            for r in ranks
        ]
        return ckpt_path, dist_attr_path

    def _check_save_load(self, tag, fetch_loss):
        """Train, checkpoint, reload and replay, then compare final losses.

        Args:
            tag: strategy name ("dp", "mp" or "pp") used in the checkpoint
                directory names.
            fetch_loss: whether this rank fetches the loss. Only ranks that
                fetch it compare the replayed loss with the original one.
        """
        dist_main_prog, dist_start_prog, loss = get_distributed_program()
        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        num_samples = self._NUM_STEPS * self._BATCH_SIZE
        inputs = np.random.random(size=(num_samples, 64)).astype('float32')
        labels = np.random.random(size=(num_samples, 1)).astype('float32')
        fetch_list = [loss] if fetch_loss else None
        rank = paddle.distributed.get_rank()

        for step in range(self._NUM_STEPS):
            if step == self._SAVE_STEP:
                path = "./output_{}{}".format(tag, rank)
                os.makedirs(path, exist_ok=True)
                # Remove this rank's checkpoint even when the test fails.
                self.addCleanup(shutil.rmtree, path, ignore_errors=True)
                save_distributed_checkpoint(dist_main_prog, path, path)
            res = self._run_step(
                exe, dist_main_prog, inputs, labels, step, fetch_list
            )
        last_res = res[0] if fetch_loss else None

        # Restore the state saved before _SAVE_STEP and replay the tail.
        ckpt_path, dist_attr_path = self._checkpoint_files(tag)
        load_checkpoint_into_program(ckpt_path, dist_attr_path, dist_main_prog)
        for step in range(self._SAVE_STEP, self._NUM_STEPS):
            res = self._run_step(
                exe, dist_main_prog, inputs, labels, step, fetch_list
            )

        if fetch_loss:
            self.assertEqual(last_res, res[0])

    def test_mlp_dp(self):
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = "dp"
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        self._check_save_load("dp", fetch_loss=True)

    def test_mlp_mp(self):
        global _global_parallel_strategy
        global _global_process_mesh
        _global_parallel_strategy = "mp"
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        self._check_save_load("mp", fetch_loss=True)

    def test_mlp_pp(self):
        global _global_parallel_strategy
        global _global_process_mesh
        global PP_MESH_0
        global PP_MESH_1
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
        PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
        # The label and linear1 are placed on PP_MESH_1 (rank 1), so rank 0
        # runs its stage without fetching the loss.
        self._check_save_load(
            "pp", fetch_loss=paddle.distributed.get_rank() != 0
        )


if __name__ == "__main__":
    # Discover and run every TestCase defined in this module.
    unittest.main(module="__main__")