# auto_parallel_data_unshard.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import copy
import numpy as np
import random

import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.distributed.auto_parallel as auto
import paddle.nn.functional as F
from paddle.distributed import fleet

# Module-level setup shared by every test below: switch Paddle to static-graph
# mode (the tests build paddle.static.Program objects) and initialize the
# parallel environment once, at import time, before any test runs.
paddle.enable_static()
paddle.distributed.init_parallel_env()


class TestDataUnshard(unittest.TestCase):
    """Check which slice of the fed batch each rank sees after auto-parallel.

    Each test builds the same small two-layer MLP (linear -> gelu -> linear,
    squared-error loss) on a 2-process mesh, annotates it with a different
    sharding, lets the semi-auto fleet optimizer partition it, runs the
    distributed main program once and compares the data tensor fetched on
    this rank against the expected slice of the global input batch.
    """

    @staticmethod
    def _create_model(train_program, start_program, batch_size,
                      data_dims_mapping, weight0_dims_mapping,
                      weight1_dims_mapping):
        """Build the serial MLP and attach the given sharding annotations.

        Args:
            train_program: Program that receives the forward ops.
            start_program: Program that receives the initializer ops.
            batch_size: leading dimension of the ``input``/``label`` data.
            data_dims_mapping: dims mapping applied to both ``input`` and
                ``label``.
            weight0_dims_mapping: dims mapping for the first linear weight.
            weight1_dims_mapping: dims mapping for the second linear weight.

        Returns:
            ``(train_program, start_program, loss, input, label)``.
        """
        with paddle.static.program_guard(train_program, start_program):
            mesh = auto.ProcessMesh([0, 1])
            input_var = paddle.static.data(name='input', shape=[batch_size, 8])
            label_var = paddle.static.data(name='label', shape=[batch_size, 8])

            weight_attr = paddle.ParamAttr(
                initializer=nn.initializer.Normal(mean=0.0, std=0.02))
            linear0 = nn.Linear(8, 8, weight_attr)
            linear1 = nn.Linear(8, 8, weight_attr)

            # Same annotation order as before: input, label, weight0, weight1.
            # Each call gets its own list copy so no two tensors share a
            # dims_mapping object.
            for tensor, dims_mapping in ((input_var, data_dims_mapping),
                                         (label_var, data_dims_mapping),
                                         (linear0.weight, weight0_dims_mapping),
                                         (linear1.weight, weight1_dims_mapping)):
                auto.shard_tensor(tensor,
                                  dist_attr={
                                      "process_mesh": mesh,
                                      "dims_mapping": list(dims_mapping)
                                  })

            linear0_out = linear0(input_var)
            gelu_out = F.gelu(linear0_out)
            linear1_out = linear1(gelu_out)
            error_cost = paddle.nn.functional.square_error_cost(
                linear1_out, label_var)
            loss = paddle.mean(error_cost)
        return train_program, start_program, loss, input_var, label_var

    @staticmethod
    def _parallelize(loss, start_program):
        """Partition the serial program with the semi-auto fleet optimizer.

        Returns:
            ``(distributed_startup_program, distributed_main_program)`` as
            returned by the distributed optimizer's ``minimize``.
        """
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)

        optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
                                                         beta1=0.9,
                                                         beta2=0.999,
                                                         epsilon=1e-08,
                                                         grad_clip=None)
        optimizer = fleet.distributed_optimizer(optimizer)
        _, _, distributed_startup_program, distributed_main_program = \
            optimizer.minimize(loss, start_program)
        return distributed_startup_program, distributed_main_program

    @staticmethod
    def _init_worker(distributed_startup_program):
        """Seed all RNGs per rank, create a GPU executor, run startup.

        Returns:
            ``(worker_index, exe)``: this process's rank and the executor.
        """
        worker_index = paddle.distributed.get_rank()
        # Rank-dependent seed; label data drawn later from np.random
        # therefore differs between ranks.
        seed = worker_index + 2021
        paddle.seed(seed)
        random.seed(seed)
        np.random.seed(seed)

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(distributed_startup_program)
        return worker_index, exe

    def test_dp2pp1mp1(self):
        """Batch split over 2 ranks: each rank should get one row of input."""
        train_program = paddle.static.Program()
        start_program = paddle.static.Program()
        # serial program: data split on dim 0, weights replicated
        train_program, start_program, loss, _, _ = self._create_model(
            train_program,
            start_program,
            batch_size=2,
            data_dims_mapping=[0, -1],
            weight0_dims_mapping=[-1, -1],
            weight1_dims_mapping=[-1, -1])

        dist_startup, dist_main = self._parallelize(loss, start_program)
        worker_index, exe = self._init_worker(dist_startup)

        input_data = np.arange(2 * 8).reshape([2, 8]).astype("float32")
        label_data = np.random.randint(0, 10, [2, 8]).astype("float32")

        # NOTE(review): these names are presumably the outputs of the split
        # op inserted by resharding, one per rank — confirm if reshard changes.
        shard_var = ('split@RESHARD.tmp_0'
                     if worker_index == 0 else 'split@RESHARD.tmp_1')
        _, shard_data_np = exe.run(dist_main,
                                   feed={
                                       "input": input_data,
                                       "label": label_data
                                   },
                                   fetch_list=[loss.name, shard_var])

        desired = input_data[worker_index].reshape(shard_data_np.shape)
        np.testing.assert_allclose(shard_data_np, desired)

    def dp1pp1mp2(self):
        """Weights sharded over 2 ranks, data replicated: full input expected.

        NOTE(review): this method lacks the ``test_`` prefix, so unittest
        does not collect or run it. Confirm whether it is intentionally
        disabled before renaming it to ``test_dp1pp1mp2``.
        """
        train_program = paddle.static.Program()
        start_program = paddle.static.Program()
        # serial program: data replicated, linear0 column-sharded,
        # linear1 row-sharded
        train_program, start_program, loss, _, _ = self._create_model(
            train_program,
            start_program,
            batch_size=8,
            data_dims_mapping=[-1, -1],
            weight0_dims_mapping=[-1, 0],
            weight1_dims_mapping=[0, -1])

        dist_startup, dist_main = self._parallelize(loss, start_program)
        _, exe = self._init_worker(dist_startup)

        input_data = np.arange(8 * 8).reshape([8, 8]).astype("float32")
        label_data = np.random.randint(0, 10, [8, 8]).astype("float32")

        _, shard_data_np = exe.run(dist_main,
                                   feed={
                                       "input": input_data,
                                       "label": label_data
                                   },
                                   fetch_list=[loss.name, 'input'])

        desired = input_data.reshape(shard_data_np.shape)
        np.testing.assert_allclose(shard_data_np, desired)


# Entry point when the file is executed directly. NOTE(review): presumably
# launched once per rank by a distributed launcher (the tests read
# paddle.distributed.get_rank()) — confirm against the test driver.
if __name__ == "__main__":
    unittest.main()