# auto_parallel_data_unshard.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest

import copy
import numpy as np
import random

import paddle
import paddle.nn as nn
import paddle.fluid.core as core
import paddle.distributed.auto_parallel as auto
import paddle.nn.functional as F
from paddle.distributed import fleet

paddle.enable_static()
paddle.distributed.init_parallel_env()


class TestDataUnshard(unittest.TestCase):
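    # Both cases build a small two-layer MLP under semi-auto parallelism and
    # check that the data each rank actually receives matches the sharding
    # annotation placed on the input tensor.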

    def test_dp2pp1mp1(self):
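        # DP2 case: input/label are sharded along mesh axis "x" (2 processes),
        # so each rank should receive one row of the global [2, 8] batch.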

        def create_model(train_program, start_program):
            with paddle.static.program_guard(train_program, start_program):

                MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"])
                input = paddle.static.data(name='input', shape=[2, 8])
                label = paddle.static.data(name='label', shape=[2, 8])

                weight_attr = paddle.ParamAttr(
                    initializer=nn.initializer.Normal(mean=0.0, std=0.02))
                linear0 = nn.Linear(8, 8, weight_attr)
                linear1 = nn.Linear(8, 8, weight_attr)

                auto.shard_tensor(input, MESH_0, ["x", None])
                auto.shard_tensor(label, MESH_0, ["x", None])
                auto.shard_tensor(linear0.weight, MESH_0, [None, None])
                auto.shard_tensor(linear1.weight, MESH_0, [None, None])

                linear0_out = linear0(input)
                gelu_out = F.gelu(linear0_out)
                linear1_out = linear1(gelu_out)
                error_cost = paddle.nn.functional.square_error_cost(
                    linear1_out, label)
                loss = paddle.mean(error_cost)
                return train_program, start_program, loss, input, label

        train_program = paddle.static.Program()
        start_program = paddle.static.Program()
        # serial program
        train_program, start_program, loss, input, label = create_model(
            train_program, start_program)

        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)
        optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
                                                         beta1=0.9,
                                                         beta2=0.999,
                                                         epsilon=1e-08,
                                                         grad_clip=None)

        optimizer = fleet.distributed_optimizer(optimizer)
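        # minimize() partitions the serial program and returns the per-rank
        # (distributed) startup and main programs.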
        _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
            loss, start_program)

        worker_index = paddle.distributed.get_rank()
        paddle.seed(worker_index + 2021)
        random.seed(worker_index + 2021)
        np.random.seed(worker_index + 2021)

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(distributed_startup_program)

        input_data = np.array(range(2 * 8)).reshape([2, 8]).astype("float32")
        label_data = np.random.randint(0, 10, [2, 8]).astype("float32")

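        # 'split@RESHARD.tmp_0' / 'tmp_1' are the intermediate outputs of the
        # split op that reshards the fed input: rank 0 fetches tmp_0 and rank 1
        # fetches tmp_1, i.e. the slice of the batch each rank receives.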
        fetchs = [loss.name, 'split@RESHARD.tmp_0'] if worker_index == 0 else [
            loss.name, 'split@RESHARD.tmp_1'
        ]
        loss_np, shard_data_np = exe.run(distributed_main_program,
                                         feed={
                                             "input": input_data,
                                             "label": label_data
                                         },
                                         fetch_list=fetchs)
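        # Each rank should hold exactly its own row of the global input batch.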
        desired = input_data[worker_index].reshape(shard_data_np.shape)
        np.testing.assert_allclose(shard_data_np, desired)

    def dp1pp1mp2(self):
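        # MP2 case: input/label stay replicated while the linear weights are
        # split across the mesh, so every rank should see the full [8, 8] batch.
        # (No `test_` prefix, so unittest does not collect this case automatically.)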

        def create_model(train_program, start_program):
            with paddle.static.program_guard(train_program, start_program):

                MESH_0 = auto.ProcessMesh([0, 1], dim_names=["x"])
                input = paddle.static.data(name='input', shape=[8, 8])
                label = paddle.static.data(name='label', shape=[8, 8])

                weight_attr = paddle.ParamAttr(
                    initializer=nn.initializer.Normal(mean=0.0, std=0.02))
                linear0 = nn.Linear(8, 8, weight_attr)
                linear1 = nn.Linear(8, 8, weight_attr)

                auto.shard_tensor(input, MESH_0, [None, None])
                auto.shard_tensor(label, MESH_0, [None, None])
                auto.shard_tensor(linear0.weight, MESH_0, [None, "x"])
                auto.shard_tensor(linear1.weight, MESH_0, ["x", None])

                linear0_out = linear0(input)
                gelu_out = F.gelu(linear0_out)

                linear1_out = linear1(gelu_out)

                error_cost = paddle.nn.functional.square_error_cost(
                    linear1_out, label)
                loss = paddle.mean(error_cost)
                return train_program, start_program, loss, input, label

        train_program = paddle.static.Program()
        start_program = paddle.static.Program()
        # serial program
        train_program, start_program, loss, input, label = create_model(
            train_program, start_program)

        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.semi_auto = True
        fleet.init(is_collective=True, strategy=dist_strategy)
        optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
                                                         beta1=0.9,
                                                         beta2=0.999,
                                                         epsilon=1e-08,
                                                         grad_clip=None)

        optimizer = fleet.distributed_optimizer(optimizer)
        _, _, distributed_startup_program, distributed_main_program = optimizer.minimize(
            loss, start_program)

        worker_index = paddle.distributed.get_rank()
        paddle.seed(worker_index + 2021)
        random.seed(worker_index + 2021)
        np.random.seed(worker_index + 2021)

        place = paddle.set_device("gpu")
        exe = paddle.static.Executor(place)
        exe.run(distributed_startup_program)

        input_data = np.array(range(8 * 8)).reshape([8, 8]).astype("float32")
        label_data = np.random.randint(0, 10, [8, 8]).astype("float32")

        fetchs = [loss.name, 'input']
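        # The input is not resharded in this case, so the fed 'input' variable
        # itself is fetched and should equal the full batch on every rank.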
        loss_np, shard_data_np = exe.run(distributed_main_program,
                                         feed={
                                             "input": input_data,
                                             "label": label_data
                                         },
                                         fetch_list=fetchs)

        desired = input_data.reshape(shard_data_np.shape)
        np.testing.assert_allclose(shard_data_np, desired)


if __name__ == "__main__":
    unittest.main()