# auto_parallel_save_load.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import random
import numpy as np
import os
import shutil

import paddle
import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
from paddle.distributed.fleet import auto

from paddle.distributed import fleet
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_checkpoint_into_program

# Every program in this file is built in static-graph mode.
paddle.enable_static()

# Strategy name ("dp", "mp" or "pp") and process mesh chosen by the running
# test; read by MLPLayer.forward and mlp_forward when sharding tensors.
_global_parallel_strategy = _global_process_mesh = None
# Per-stage meshes, assigned only by the pipeline-parallel ("pp") test.
PP_MESH_0 = PP_MESH_1 = None


class MLPLayer(nn.Layer):
    """Feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

    ``forward`` annotates both linear weights with a sharding spec chosen by
    the module-level ``_global_parallel_strategy`` ("pp", "mp" or "dp").
    """

    def __init__(self,
                 hidden_size=64,
                 intermediate_size=4 * 64,
                 initializer_range=0.02):
        """Create the layers.

        Args:
            hidden_size (int): input/output feature size (``d_model``).
            intermediate_size (int): feature size between the two linears.
            initializer_range (float): standard deviation of the normal
                distribution the weight initializer array is drawn from.
        """
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        # Fixed seed so the drawn initial weights are identical in every
        # process that builds this layer.
        np.random.seed(2021)
        arr = np.random.normal(0,
                               initializer_range,
                               size=(d_model, dim_feedforward))
        weight_attr = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        # NOTE(review): linear1 maps dim_feedforward -> d_model but reuses the
        # initializer built from a (d_model, dim_feedforward) array; confirm
        # this is intended when the two sizes differ.
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, input):
        """Annotate weight sharding for the active strategy, then compute."""
        if _global_parallel_strategy == "pp":
            # Each linear is placed on its own pipeline-stage mesh.
            auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
            auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
        elif _global_parallel_strategy == "mp":
            # Shard dim 1 of linear0's weight and dim 0 of linear1's weight
            # across mesh axis "x".
            auto.shard_tensor(self.linear0.weight, _global_process_mesh,
                              [None, "x"])
            auto.shard_tensor(self.linear1.weight, _global_process_mesh,
                              ["x", None])
        elif _global_parallel_strategy == "dp":
            # Data parallel: weights are not split along any mesh axis.
            auto.shard_tensor(self.linear0.weight, _global_process_mesh,
                              [None, None])
            auto.shard_tensor(self.linear1.weight, _global_process_mesh,
                              [None, None])

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        return out


def mlp_forward(train_program, start_program):
    """Build the MLP regression graph (mean squared error) in the programs.

    Args:
        train_program (static.Program): program that receives the forward ops.
        start_program (static.Program): startup program paired with it.

    Returns:
        tuple: ``(loss, train_program, start_program)``.
    """
    with static.program_guard(train_program, start_program), \
            utils.unique_name.guard():

        batch_size = 4
        hidden_size = 64
        input_data = static.data(name="input",
                                 shape=[batch_size, hidden_size],
                                 dtype='float32')
        label = static.data(name="label",
                            shape=[batch_size, 1],
                            dtype='float32')

        # Annotate the feed variables with the layout the strategy expects.
        if _global_parallel_strategy == "pp":
            # Input sits on linear0's mesh, label on linear1's mesh.
            auto.shard_tensor(input_data, PP_MESH_0, [None, None])
            auto.shard_tensor(label, PP_MESH_1, [None, None])
        elif _global_parallel_strategy == "dp":
            # Split dim 0 (batch_size) across mesh axis "x".
            auto.shard_tensor(input_data, _global_process_mesh, ["x", None])
        elif _global_parallel_strategy == "mp":
            auto.shard_tensor(input_data, _global_process_mesh, [None, None])

        mlp = MLPLayer(hidden_size=hidden_size,
                       intermediate_size=4 * hidden_size,
                       initializer_range=0.02)

        predict = mlp(input_data)
        error_cost = paddle.nn.functional.square_error_cost(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program


def get_distributed_program():
    """Build the MLP and partition it with fleet's semi-auto parallel mode.

    Returns:
        tuple: ``(dist_main_prog, dist_startup_prog, loss)`` for this rank.
    """
    main_prog, startup_prog = static.Program(), static.Program()

    strategy = fleet.DistributedStrategy()
    strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=strategy)

    loss, main_prog, startup_prog = mlp_forward(main_prog, startup_prog)

    sgd = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
    dist_sgd = fleet.distributed_optimizer(sgd)
    # minimize() yields four values; only the partitioned programs are used.
    (_, _, dist_startup_prog,
     dist_main_prog) = dist_sgd.minimize(loss, startup_prog)

    return dist_main_prog, dist_startup_prog, loss


class TestMLPSaveLoad(unittest.TestCase):
    """Checkpoint mid-training, reload, replay, and compare the final loss.

    Each test trains for ``_NUM_STEPS`` steps, saving a distributed checkpoint
    just before step ``_SAVE_STEP``. It then loads the checkpoints of all
    ranks back into the program, replays the remaining steps, and asserts the
    final loss equals the loss from the uninterrupted run.
    """

    # Ranks covered by the process meshes used below (mesh [0, 1]).
    _NUM_RANKS = 2
    _NUM_STEPS = 20
    _SAVE_STEP = 10
    # Must match batch_size in mlp_forward.
    _BATCH_SIZE = 4

    def setUp(self):
        paddle.seed(2021)
        random.seed(2021)
        np.random.seed(2021)

    def _run_save_load(self, tag, fetch_ranks=None):
        """Run the train / save / load / replay cycle for the active strategy.

        Args:
            tag (str): strategy name used in checkpoint directory names,
                e.g. "dp" -> "./output_dp<rank>".
            fetch_ranks (tuple|None): ranks that fetch the loss and compare
                it; ``None`` means every rank does.
        """
        dist_main_prog, dist_start_prog, loss = get_distributed_program()
        rank = paddle.distributed.get_rank()
        save_dir = "./output_{}{}".format(tag, rank)
        # Remove this rank's checkpoint even when an assertion fails.
        self.addCleanup(shutil.rmtree, save_dir, ignore_errors=True)
        fetch = fetch_ranks is None or rank in fetch_ranks

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        bs = self._BATCH_SIZE
        num_samples = self._NUM_STEPS * bs
        inputs = np.random.random(size=(num_samples, 64)).astype('float32')
        labels = np.random.random(size=(num_samples, 1)).astype('float32')

        def run_step(step):
            # Feed the step-th mini-batch; fetch the loss only on ranks that
            # are expected to hold it.
            feed = {
                "input": inputs[step * bs:(step + 1) * bs, :],
                "label": labels[step * bs:(step + 1) * bs, :]
            }
            if fetch:
                return exe.run(dist_main_prog, feed=feed, fetch_list=[loss])
            return exe.run(dist_main_prog, feed=feed)

        for step in range(self._NUM_STEPS):
            if step == self._SAVE_STEP:
                os.makedirs(save_dir, exist_ok=True)
                save_distributed_checkpoint(dist_main_prog, save_dir,
                                            save_dir)
            res = run_step(step)
        last_res = res[0] if fetch else None

        ranks = range(self._NUM_RANKS)
        ckpt_path = [
            "./output_{0}{1}/model_state_rank{1}.pdmodel".format(tag, r)
            for r in ranks
        ]
        dist_attr_path = [
            "./output_{0}{1}/dist_attr_rank{1}.pdattr".format(tag, r)
            for r in ranks
        ]
        load_checkpoint_into_program(ckpt_path, dist_attr_path, dist_main_prog)

        for step in range(self._SAVE_STEP, self._NUM_STEPS):
            res = run_step(step)

        if fetch:
            self.assertEqual(last_res, res[0])

    def test_mlp_dp(self):
        """Data-parallel save/load round trip."""
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "dp"
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        self._run_save_load("dp")

    def test_mlp_mp(self):
        """Model-parallel save/load round trip."""
        global _global_parallel_strategy, _global_process_mesh
        _global_parallel_strategy = "mp"
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        self._run_save_load("mp")

    def test_mlp_pp(self):
        """Pipeline-parallel save/load round trip.

        Only rank 1 (the mesh holding the label and linear1) fetches the loss.
        """
        global _global_parallel_strategy, _global_process_mesh
        global PP_MESH_0, PP_MESH_1
        _global_parallel_strategy = "pp"
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
        PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])
        self._run_save_load("pp", fetch_ranks=(1,))


# Run the test cases above when this file is executed directly.
if __name__ == "__main__":
    unittest.main()