#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
from .pp_utils import p2p_communication as p2p
import paddle.fluid.core as core

__all__ = []


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = self._hcg.get_sharding_parallel_world_size(
        ) > 1

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']
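        # one training batch is consumed as `accumulate_steps` micro batches
        # of `micro_batch_size` samples each, so the incoming batch size must
        # equal micro_batch_size * accumulate_steps.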

        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()

        p2p.initialize_p2p_groups(hcg, self._using_cache)
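        # initialize_p2p_groups builds the send/recv groups between adjacent
        # pipeline stages; the p2p_cache_shape option lets the shape/dtype
        # metadata of the exchanged tensors be cached after the first exchange.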

        _initialize_recompute_hcg(hcg)

        self.is_first_stage = self.stage_id == 0
        self.is_last_stage = (self.stage_id == (self.num_stages - 1))
        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        self._compute_loss = True

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def forward_backward_pipeline(self, data, scaler=None):
        # use the 1f1b scheduling strategy.
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        self.scaler = scaler

        # store data for train
        self.data = data

        # store total loss of entire batch
        self.total_loss = None

        # store data id for micro_batch
        self.micro_batch_id = 0

        startup_steps = (self.num_stages - self.stage_id - 1)
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
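        # 1F1B schedule: run `startup_steps` warmup forward passes first
        # (later stages need fewer), then `steady_steps` iterations that each
        # pair one forward pass with one backward pass.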

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward()

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor)

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward()

        for i in range(steady_steps):
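            # 1F1B steady phase: run a forward pass, trade this stage's output
            # activation for its gradient with the next stage, then run a
            # backward pass on the oldest buffered micro batch.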
            last_iter = (i == (steady_steps - 1))

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(output_tensor)

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)

            if last_iter:
                input_tensor = None
                p2p.send_backward(input_tensor_grad)
            else:
                input_tensor = p2p.send_backward_recv_forward(input_tensor_grad)

        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward()

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)
            p2p.send_backward(input_tensor_grad)

        self._layers.allreduce_shared_weight_gradients()
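        # allreduce_shared_weight_gradients synchronizes gradients of
        # parameters shared across pipeline stages (e.g. tied embedding
        # weights) before the final loss is broadcast.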
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
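        # Runs one batch through the 1F1B schedule and applies a single
        # optimizer step. Rough usage sketch (assuming the usual fleet
        # hybrid-parallel setup; the wrapper calls below are illustrative):
        #   model = fleet.distributed_model(pipeline_layer)  # -> PipelineParallel
        #   opt = fleet.distributed_optimizer(opt)           # -> HybridParallelOptimizer
        #   loss = model.train_batch([inputs, labels], opt, lr_scheduler, scaler)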
        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')

        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_first_stage or self.is_last_stage:
            assert data is not None, (
                "For the first and the last stage, the data must be set.")
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        # 1f1b for pipeline
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        self._layers.eval()
        self._compute_loss = compute_loss

        # save data for eval
        self.data = data
        # store data id for micro_batch
        self.micro_batch_id = 0

        # store total loss of entire batch
        self.total_loss = None

        startup_steps = (self.num_stages - self.stage_id - 1)
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
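        # same warmup/steady split as forward_backward_pipeline, but only
        # forward passes are executed during evaluation.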

        input_buffers = []
        output_buffers = []

        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward()

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor)

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward()

        for i in range(steady_steps):
            last_iter = (i == (steady_steps - 1))

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor)

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
                input_tensor = p2p.recv_forward()

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor):
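        # runs one micro batch through this stage's layers; the first stage
        # pulls its input from the fed data, and the last stage optionally
        # computes and accumulates the loss.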
        if self.stage_id == 0:
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        output_tensor = self._layers.forward(input_tensor)

        if self.is_last_stage:
            # compute the loss when training
            if self._compute_loss:
                assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
                labels = self._load_micro_batch(self.micro_batch_id)
                output_tensor = self._layers._loss_fn(output_tensor, labels)
                assert isinstance(output_tensor, (
                    paddle.Tensor, core.eager.Tensor
                )), "Currently, loss_fn should return a paddle.Tensor"

                with paddle.amp.auto_cast(enable=False):
                    if self.accumulate_steps > 1:
                        output_tensor = output_tensor / self.accumulate_steps

                    if self.total_loss is None:
                        self.total_loss = paddle.zeros_like(output_tensor)
                    self.total_loss += output_tensor.detach()

        self.micro_batch_id += 1
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
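        # the last stage starts backward from the (optionally AMP-scaled)
        # loss, other stages from the gradient received for their forward
        # output; returns the gradient of this stage's input so it can be
        # sent to the previous stage.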
        with paddle.amp.auto_cast(enable=False):
            if self.is_last_stage:
                assert output_tensor_grad is None
                if self.scaler:
                    paddle.autograd.backward(self.scaler.scale(output_tensor))
                else:
                    paddle.autograd.backward(output_tensor)
            else:
                if isinstance(output_tensor, tuple):
                    outputs = [t for t in output_tensor if not t.stop_gradient]
                    assert len(outputs) == len(output_tensor_grad)
                    paddle.autograd.backward(
                        tensors=outputs,
                        grad_tensors=[t for t in output_tensor_grad])
                else:
                    paddle.autograd.backward(
                        tensors=[output_tensor],
                        grad_tensors=[output_tensor_grad])

            input_tensor_grad = None
            if input_tensor is not None:
                if isinstance(input_tensor, tuple):
                    input_tensor_grad = tuple(
                        [t.grad for t in input_tensor if not t.stop_gradient])
                else:
                    input_tensor_grad = input_tensor.grad
            return input_tensor_grad

    def _load_micro_batch(self, cache_id):
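        # slices micro batch `cache_id` out of the full batch: the first
        # stage reads the model inputs (inputs[0]), the last stage reads the
        # labels (inputs[1]), and intermediate stages need no data.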
        inputs = self.data
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if self.is_first_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if isinstance(inputs[0], tuple):
                assert len(
                    inputs[0]
                ) > 1, "If you use tuple for input data, it should have at least two inputs."
                batch_size = inputs[0][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
                    "batch_size should equal micro_batch_size * accumulate_steps. "
                    "Currently, batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
                    %
                    (batch_size, self.micro_batch_size, self.accumulate_steps))
                data = [input[begin:end, :].detach() for input in inputs[0]]
                return tuple(data)
            else:
                batch_size = inputs[0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return inputs[0][begin:end, :].detach()
        elif self.is_last_stage:
            assert len(inputs) == 2, "length of input should be 2"
            if isinstance(inputs[1], tuple):
                batch_size = inputs[1][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                data = [input[begin:end, :].detach() for input in inputs[1]]
                return tuple(data)
            else:
                batch_size = inputs[1].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return inputs[1][begin:end, :].detach()
        else:
            # No data input is required for other stages
            inputs = None

    def _broadcast_final_loss(self):
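        # the last stage broadcasts the accumulated loss (plus a flag for its
        # dtype) to every rank in the pipeline group, so all ranks return the
        # same loss value.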
        if self.is_last_stage:
            assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
            loss = self.total_loss.detach()
            is_fp32 = paddle.to_tensor(
                1) if loss.dtype == paddle.float32 else paddle.to_tensor(0)
            paddle.distributed.broadcast(
                is_fp32,
                src=self.global_rank,
                use_calc_stream=True,
                group=self.pp_group)
            paddle.distributed.broadcast(
                loss,
                src=self.global_rank,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            is_fp32 = paddle.to_tensor(1)
            paddle.distributed.broadcast(
                is_fp32,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
            loss = paddle.zeros(
                shape=[1],
                dtype="float32") if is_fp32.numpy()[0] else paddle.zeros(
                    shape=[1], dtype="float16")
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
        return loss

    def _optimizer_step(self):
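        # steps through the GradScaler when AMP loss scaling is enabled,
        # otherwise takes a plain optimizer step, then clears gradients and
        # advances the lr scheduler if one was given.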
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()