# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import random
import numpy as np
import os
import shutil

import paddle
import paddle.nn as nn
import paddle.utils as utils
import paddle.static as static
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto

from paddle.distributed import fleet
from paddle.fluid.initializer import NumpyArrayInitializer
from paddle.distributed.auto_parallel.utils import save_distributed_checkpoint, load_checkpoint_into_program

paddle.enable_static()
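# Globals set by each test case to select the parallel strategy ("dp", "mp",
# or "pp") and the process meshes used by the sharding annotations below.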
_global_parallel_strategy = None
_global_process_mesh = None
PP_MESH_0 = None
PP_MESH_1 = None


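# Two-layer MLP with LayerNorm; its weights are annotated with
# auto.shard_tensor according to the selected parallel strategy.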
class MLPLayer(nn.Layer):

    def __init__(self,
                 hidden_size=64,
                 intermediate_size=4 * 64,
                 initializer_range=0.02):
        super(MLPLayer, self).__init__()
        d_model = hidden_size
        dim_feedforward = intermediate_size
        np.random.seed(2021)
        arr = np.random.normal(0, 0.02, size=(d_model, dim_feedforward))
        weight_attr = paddle.ParamAttr(initializer=NumpyArrayInitializer(arr))
        bias_attr = None

        self.linear0 = nn.Linear(d_model,
                                 dim_feedforward,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.linear1 = nn.Linear(dim_feedforward,
                                 d_model,
                                 weight_attr,
                                 bias_attr=bias_attr)
        self.norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, input):
        if _global_parallel_strategy == "pp":
            auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None])
            auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None])
        elif _global_parallel_strategy == "mp":
            auto.shard_tensor(self.linear0.weight, _global_process_mesh,
                              [None, "x"])
            auto.shard_tensor(self.linear1.weight, _global_process_mesh,
                              ["x", None])
        elif _global_parallel_strategy == "dp":
            auto.shard_tensor(self.linear0.weight, _global_process_mesh,
                              [None, None])
            auto.shard_tensor(self.linear1.weight, _global_process_mesh,
                              [None, None])

        out = self.norm(input)
        out = self.linear0(out)
        out = F.gelu(out, approximate=True)
        out = self.linear1(out)

        return out


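# Builds the serial forward program: data feeds, optional sharding annotations
# on the inputs, the MLP, and a mean squared error loss.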
def mlp_forward(train_program, start_program):
    with static.program_guard(train_program, start_program), \
        utils.unique_name.guard():

        batch_size = 4
        hidden_size = 64
        input = static.data(name="input",
                            shape=[batch_size, hidden_size],
                            dtype='float32')
        label = static.data(name="label",
                            shape=[batch_size, 1],
                            dtype='float32')

        if _global_parallel_strategy == "pp":
            auto.shard_tensor(input, PP_MESH_0, [None, None])
            auto.shard_tensor(label, PP_MESH_1, [None, None])
        elif _global_parallel_strategy == "dp":
            auto.shard_tensor(input, _global_process_mesh, ["x", None])
        elif _global_parallel_strategy == "mp":
            auto.shard_tensor(input, _global_process_mesh, [None, None])

        mlp = MLPLayer(hidden_size=hidden_size,
                       intermediate_size=4 * hidden_size,
                       initializer_range=0.02)

        predict = mlp(input)
        error_cost = paddle.nn.functional.square_error_cost(predict, label)
        loss = paddle.mean(error_cost)

    return loss, train_program, start_program


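# Builds the serial program, then applies the auto-parallel (semi_auto) pass
# via fleet.distributed_optimizer to obtain the per-rank distributed
# main/startup programs and the loss variable.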
def get_distributed_program():
    train_program = static.Program()
    startup_program = static.Program()

    dist_strategy = fleet.DistributedStrategy()
    dist_strategy.semi_auto = True
    fleet.init(is_collective=True, strategy=dist_strategy)

    loss, train_program, startup_program = mlp_forward(train_program,
                                                       startup_program)

    optimizer = paddle.fluid.optimizer.SGDOptimizer(learning_rate=0.01)
    optimizer = fleet.distributed_optimizer(optimizer)
    _, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
        loss, startup_program)

    return dist_main_prog, dist_startup_prog, loss


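# Each test trains the MLP for 20 steps, saves a distributed checkpoint at
# step 10, reloads it, replays steps 10-19, and checks that the final loss
# matches the original run.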
class TestMLPSaveLoad(unittest.TestCase):

    def setUp(self):
        paddle.seed(2021)
        random.seed(2021)
        np.random.seed(2021)

    def test_mlp_dp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "dp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])

        dist_main_prog, dist_start_prog, loss = get_distributed_program()
        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        input = np.random.random(size=(80, 64)).astype('float32')
        label = np.random.random(size=(80, 1)).astype('float32')
        for step in range(20):
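            # Save a distributed checkpoint (per-rank model state and
            # distributed attributes) halfway through training.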
            if step == 10:
                path = "./output_dp{}".format(paddle.distributed.get_rank())
                os.makedirs(path, exist_ok=True)
                save_distributed_checkpoint(dist_main_prog, path, path)

            res = exe.run(dist_main_prog,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss])

        last_res = res[0]
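        # Reload both ranks' checkpoints into the program; replaying steps
        # 10-19 below must reproduce the original final loss.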
        ckpt_path = [
            "./output_dp0/model_state_rank0.pdmodel",
            "./output_dp1/model_state_rank1.pdmodel"
        ]
        dist_attr_path = [
            "./output_dp0/dist_attr_rank0.pdattr",
            "./output_dp1/dist_attr_rank1.pdattr"
        ]
        load_checkpoint_into_program(ckpt_path, dist_attr_path, dist_main_prog)
        for step in range(10, 20):
            res = exe.run(dist_main_prog,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss])

        self.assertEqual(last_res, res[0])
        shutil.rmtree("./output_dp{}".format(paddle.distributed.get_rank()))

    def test_mlp_mp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "mp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])

        dist_main_prog, dist_start_prog, loss = get_distributed_program()

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        input = np.random.random(size=(80, 64)).astype('float32')
        label = np.random.random(size=(80, 1)).astype('float32')
        for step in range(20):
            if step == 10:
                path = "./output_mp{}".format(paddle.distributed.get_rank())
                os.makedirs(path, exist_ok=True)
                save_distributed_checkpoint(dist_main_prog, path, path)

            res = exe.run(dist_main_prog,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss])

        last_res = res[0]
        ckpt_path = [
            "./output_mp0/model_state_rank0.pdmodel",
            "./output_mp1/model_state_rank1.pdmodel"
        ]
        dist_attr_path = [
            "./output_mp0/dist_attr_rank0.pdattr",
            "./output_mp1/dist_attr_rank1.pdattr"
        ]
        load_checkpoint_into_program(ckpt_path, dist_attr_path, dist_main_prog)
        for step in range(10, 20):
            res = exe.run(dist_main_prog,
                          feed={
                              "input": input[step * 4:(step + 1) * 4, :],
                              "label": label[step * 4:(step + 1) * 4, :]
                          },
                          fetch_list=[loss])

        self.assertEqual(last_res, res[0])
        shutil.rmtree("./output_mp{}".format(paddle.distributed.get_rank()))

    def test_mlp_pp(self):
        global _global_parallel_strategy
        _global_parallel_strategy = "pp"
        global _global_process_mesh
        _global_process_mesh = auto.ProcessMesh([0, 1], dim_names=["x"])
        global PP_MESH_0
        PP_MESH_0 = auto.ProcessMesh(mesh=[0], dim_names=["x"])
        global PP_MESH_1
        PP_MESH_1 = auto.ProcessMesh(mesh=[1], dim_names=["x"])

        dist_main_prog, dist_start_prog, loss = get_distributed_program()

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(dist_start_prog)

        input = np.random.random(size=(80, 64)).astype('float32')
        label = np.random.random(size=(80, 1)).astype('float32')
        for step in range(20):
            if step == 10:
                path = "./output_pp{}".format(paddle.distributed.get_rank())
                os.makedirs(path, exist_ok=True)
                save_distributed_checkpoint(dist_main_prog, path, path)

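            # In pipeline parallel the loss is computed on the last stage
            # (rank 1), so rank 0 runs without a fetch list.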
            if paddle.distributed.get_rank() in [0]:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              })
            else:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              },
                              fetch_list=[loss])

        if paddle.distributed.get_rank() in [1]:
            last_res = res[0]

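        # Reload both ranks' checkpoints; only rank 1, which owns the loss,
        # verifies the replayed result.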
        ckpt_path = [
            "./output_pp0/model_state_rank0.pdmodel",
            "./output_pp1/model_state_rank1.pdmodel"
        ]
        dist_attr_path = [
            "./output_pp0/dist_attr_rank0.pdattr",
            "./output_pp1/dist_attr_rank1.pdattr"
        ]
        load_checkpoint_into_program(ckpt_path, dist_attr_path, dist_main_prog)
        for step in range(10, 20):
            if paddle.distributed.get_rank() in [0]:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              })
            else:
                res = exe.run(dist_main_prog,
                              feed={
                                  "input": input[step * 4:(step + 1) * 4, :],
                                  "label": label[step * 4:(step + 1) * 4, :]
                              },
                              fetch_list=[loss])

        if paddle.distributed.get_rank() in [1]:
            self.assertEqual(last_res, res[0])
        shutil.rmtree("./output_pp{}".format(paddle.distributed.get_rank()))


if __name__ == "__main__":
    unittest.main()