# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle.nn.functional as F
from paddle import nn
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_parallel import LayerDesc, PipelineLayer
from paddle.nn import Layer, Sequential


class ReshapeHelp(Layer):
    """Layer wrapper around ``Tensor.reshape``.

    Lets a reshape be placed in a ``Sequential`` or in a pipeline layer
    list alongside ordinary layers.

    Args:
        shape: target shape handed to ``reshape`` on every forward call
            (this file uses ``[-1, 256]``).
    """

    def __init__(self, shape):
        super().__init__()
        # Kept exactly as given; it is read again on each forward pass.
        self.shape = shape

    def forward(self, x):
        """Return ``x`` reshaped to ``self.shape``."""
        target_shape = self.shape
        reshaped = x.reshape(shape=target_shape)
        return reshaped


class AlexNet(Layer):
    """Small AlexNet-style CNN used as the non-pipelined reference model.

    Five convolutions (each followed by ReLU, with three max-pool layers
    interleaved), a reshape to ``[-1, 256]`` and one linear classifier.
    ``forward`` returns the cross-entropy loss, not the logits.

    Args:
        num_classes: output width of the final linear classifier.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # NOTE: keep construction and attribute-assignment order as-is.
        # Sub-layer registration order fixes the order of ``parameters()``,
        # which the pipeline tests split by position.
        conv_stack = [
            nn.Conv2D(1, 64, kernel_size=11, stride=4, padding=5),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=2, stride=2),
            nn.Conv2D(64, 192, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=2, stride=2),
            nn.Conv2D(192, 384, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2D(384, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2D(256, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2D(kernel_size=2, stride=2),
        ]
        self.features = Sequential(*conv_stack)
        self.reshape_layer = ReshapeHelp(shape=[-1, 256])
        self.classifier = nn.Linear(256, num_classes)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x, y):
        """Run the network on ``x`` and return its loss against labels ``y``."""
        logits = self.classifier(self.reshape_layer(self.features(x)))
        return self.loss_fn(logits, y)


class AlexNetPipe(AlexNet):
    """``AlexNet`` that can be flattened into a plain list of layers.

    The flat list is what this file hands to ``PipelineLayer(layers=...)``.
    """

    def to_layers(self):
        """Return the model's layers in execution order.

        The list holds every entry of ``self.features`` (by index), followed
        by ``self.reshape_layer`` and ``self.classifier``. ``self.loss_fn`` is
        not included; in this file the loss is passed to ``PipelineLayer``
        separately.
        """
        num_feature_layers = len(self.features)
        body = [self.features[pos] for pos in range(num_feature_layers)]
        # Network head (flatten, then classify). This is not the loss.
        head = [self.reshape_layer, self.classifier]
        return body + head


class AlexNetPipeDesc(PipelineLayer):
    """AlexNet expressed as a ``PipelineLayer`` built from layer descriptions.

    Mirrors the layer sequence of ``AlexNet``, but lists ``LayerDesc``
    entries (plus plain ``F.relu`` callables for all activations after the
    first) instead of already-constructed layers.

    Args:
        num_classes: output width of the final linear classifier.
        **kwargs: forwarded unchanged to ``PipelineLayer.__init__``
            (this file passes ``num_stages``).
    """

    def __init__(self, num_classes=10, **kwargs):
        self.num_classes = num_classes
        # NOTE(review): the mix of LayerDesc(nn.ReLU) and bare F.relu looks
        # deliberate, presumably to cover both kinds of pipeline entry.
        # Confirm before unifying.
        layer_descs = [
            # block 1
            LayerDesc(nn.Conv2D, 1, 64, kernel_size=11, stride=4, padding=5),
            LayerDesc(nn.ReLU),
            LayerDesc(nn.MaxPool2D, kernel_size=2, stride=2),
            # block 2
            LayerDesc(nn.Conv2D, 64, 192, kernel_size=5, padding=2),
            F.relu,
            LayerDesc(nn.MaxPool2D, kernel_size=2, stride=2),
            # blocks 3-5
            LayerDesc(nn.Conv2D, 192, 384, kernel_size=3, padding=1),
            F.relu,
            LayerDesc(nn.Conv2D, 384, 256, kernel_size=3, padding=1),
            F.relu,
            LayerDesc(nn.Conv2D, 256, 256, kernel_size=3, padding=1),
            F.relu,
            LayerDesc(nn.MaxPool2D, kernel_size=2, stride=2),
            # head: flatten + classifier
            LayerDesc(ReshapeHelp, shape=[-1, 256]),
            LayerDesc(nn.Linear, 256, self.num_classes),
        ]
        super().__init__(
            layers=layer_descs,
            loss_fn=nn.CrossEntropyLoss(),
            **kwargs,
        )


class TestPipeLayerAPI(unittest.TestCase):
    """Checks how ``PipelineLayer`` partitions AlexNet across pipeline stages.

    Every test builds a pipeline with ``pp_degree == 2`` and inspects the
    parameters held by the local stage. The local stage id comes from the
    hybrid communicate group.
    """

    def setUp(self):
        strategy = fleet.DistributedStrategy()
        self.pipeline_parallel_size = 2
        strategy.hybrid_configs = {
            "dp_degree": 1,
            "mp_degree": 1,
            "pp_degree": self.pipeline_parallel_size,
        }
        # NOTE(review): fleet.init runs before every test method; consider
        # setUpClass if repeated initialisation turns out to be costly.
        fleet.init(is_collective=True, strategy=strategy)
        self.hcg = fleet.get_hybrid_communicate_group()

    def test_pipelayer_desc(self):
        """Default partitioning of the LayerDesc model: 6 params per stage."""
        pipe_model = AlexNetPipeDesc(num_stages=self.pipeline_parallel_size)
        self.assertEqual(len(pipe_model.parameters()), 6)

    def test_pipelayer_sequential(self):
        """Each stage owns its contiguous slice of the reference parameters."""
        init_net = AlexNetPipe()
        pipe_model = PipelineLayer(
            layers=init_net.to_layers(),
            num_stages=self.pipeline_parallel_size,
            loss_fn=nn.CrossEntropyLoss(),
        )
        stage_id = self.hcg.get_stage_id()
        init_parameters = init_net.parameters()
        pipe_parameters = pipe_model.parameters()
        part_number = len(init_parameters) // self.pipeline_parallel_size

        # Fail loudly instead of silently checking nothing on an
        # unexpected stage id.
        self.assertIn(stage_id, range(self.pipeline_parallel_size))
        # Stage k holds reference parameters [k * part, (k + 1) * part).
        offset = stage_id * part_number
        for idx in range(part_number):
            param_a = init_parameters[offset + idx]
            param_b = pipe_parameters[idx]
            self.assertEqual(param_a.name, param_b.name)
            np.testing.assert_allclose(param_a.numpy(), param_b.numpy())

    def test_pipelayer_segment_method(self):
        """With ``seg_method=[0, 4]``, stage 0 holds 4 params, stage 1 holds 8."""
        init_net = AlexNetPipe()
        pipe_model = PipelineLayer(
            layers=init_net.to_layers(),
            num_stages=self.pipeline_parallel_size,
            seg_method=[0, 4],
            loss_fn=nn.CrossEntropyLoss(),
        )
        stage_id = self.hcg.get_stage_id()
        expected_num_params = {0: 4, 1: 8}
        self.assertIn(stage_id, expected_num_params)
        self.assertEqual(
            len(pipe_model.parameters()), expected_num_params[stage_id]
        )


if __name__ == '__main__':
    # Discover and run every TestCase defined in this module.
    unittest.main(module='__main__')