#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid

from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
from .parallel_layers.pp_layers import PipelineLayer
from .pp_utils import p2p_communication as p2p

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler

# No names are exported by ``from ... import *``.
__all__ = []


class PipelineParallel(MetaParallelBase):
    """Run a ``PipelineLayer`` across pipeline stages with a 1F1B schedule.

    A global batch is split into ``accumulate_steps`` micro-batches of
    ``micro_batch_size`` samples each. Activations and gradients move between
    stages through the ``p2p`` send/recv helpers. The first stage reads model
    inputs from ``data[0]`` and the last stage reads labels from ``data[1]``.
    """

    def __init__(self, layers, hcg, strategy):
        """Validate ``layers``, read pipeline configs and sync parameters.

        Args:
            layers: the model; must be a ``PipelineLayer``.
            hcg: group helper queried for parallel world sizes, the stage id,
                the pipeline group and the global rank.
            strategy: its ``pipeline_configs`` must provide
                ``micro_batch_size``, ``accumulate_steps`` and
                ``p2p_cache_shape``.

        Raises:
            TypeError: if ``layers`` is not a ``PipelineLayer``.
        """
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
        self.use_sharding_parallel = self._hcg.get_sharding_parallel_world_size(
        ) > 1

        self.total_loss = None

        pipeline_configs = self._strategy.pipeline_configs
        self.micro_batch_size = pipeline_configs['micro_batch_size']
        self.accumulate_steps = pipeline_configs['accumulate_steps']
        self._using_cache = pipeline_configs['p2p_cache_shape']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()

        p2p.initialize_p2p_groups(hcg, self._using_cache)

        _initialize_recompute_hcg(hcg)

        self.is_first_stage = self.stage_id == 0
        self.is_last_stage = (self.stage_id == (self.num_stages - 1))
        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        # Whether the last stage applies ``_loss_fn``; eval_batch may clear it.
        self._compute_loss = True

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_sharding_parallel:
            logger.info("start broadcast sharding parameters")
            broadcast_sharding_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def _reset_batch_state(self, data):
        """Store ``data`` and reset the per-batch micro-batch counter and loss."""
        self.data = data
        self.micro_batch_id = 0
        self.total_loss = None

    def _get_schedule_steps(self):
        """Return ``(startup_steps, steady_steps)`` for this stage.

        Earlier stages run more forward-only warm-up steps
        (``num_stages - stage_id - 1``), capped at ``accumulate_steps``. The
        remaining micro-batches are processed in the steady phase.
        """
        startup_steps = min(self.num_stages - self.stage_id - 1,
                            self.accumulate_steps)
        return startup_steps, self.accumulate_steps - startup_steps

    def forward_backward_pipeline(self, data, scaler=None):
        """Run forward and backward for every micro-batch of ``data``.

        Uses the 1F1B scheduling strategy, inspired by:
        https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        Args:
            data: ``(inputs, labels)`` on the first/last stage, ``None`` on
                intermediate stages.
            scaler: optional grad scaler; when set, the last stage
                back-propagates ``scaler.scale(loss)``.

        Returns:
            The accumulated loss, broadcast from the last stage.
        """
        self.scaler = scaler
        # Training always needs the loss on the last stage. A previous
        # ``eval_batch(compute_loss=False)`` would otherwise leave this off.
        self._compute_loss = True
        self._reset_batch_state(data)

        startup_steps, steady_steps = self._get_schedule_steps()

        # FIFO of micro-batches whose forward ran but backward has not.
        input_buffers = []
        output_buffers = []

        # Warm-up: forward-only passes.
        for _ in range(startup_steps):
            input_tensor = p2p.recv_forward()
            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor)

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        # Steady phase: one forward followed by one backward per step.
        if steady_steps > 0:
            input_tensor = p2p.recv_forward()

        for i in range(steady_steps):
            last_iter = (i == (steady_steps - 1))

            output_tensor = self._forward_step(input_tensor)
            output_tensor_grad = p2p.send_forward_recv_backward(output_tensor)

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            # Backward runs on the oldest in-flight micro-batch.
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)

            if last_iter:
                input_tensor = None
                p2p.send_backward(input_tensor_grad)
            else:
                input_tensor = p2p.send_backward_recv_forward(input_tensor_grad)

        # Cool-down: backward passes for micro-batches left over from warm-up.
        for _ in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward()

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)
            p2p.send_backward(input_tensor_grad)

        self._layers.allreduce_shared_weight_gradients()
        with paddle.amp.auto_cast(enable=False):
            train_loss = self._broadcast_final_loss()
        return train_loss

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
        """Run one training step (1F1B pipeline followed by an optimizer step).

        Args:
            data: ``(inputs, labels)``; required on the first and last stage,
                ignored (replaced by ``None``) on intermediate stages.
            optimizer: must be a ``HybridParallelOptimizer``.
            lr_scheduler: optional scheduler stepped after the optimizer.
            scaler: optional grad scaler used for backward and the step.

        Returns:
            The loss broadcast from the last stage.
        """
        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')

        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_first_stage or self.is_last_stage:
            assert data is not None, (
                "For the first and the last stage, the data must be set.")
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        self._layers.train()

        # 1f1b for pipeline
        train_loss = self.forward_backward_pipeline(data, scaler)

        # optimizer
        with paddle.amp.auto_cast(enable=False):
            self._optimizer_step()

        return train_loss

    def eval_batch(self, data, compute_loss=False):
        """Run forward passes for every micro-batch of ``data`` in eval mode.

        Args:
            data: ``(inputs, labels)`` on the first/last stage.
            compute_loss: if True, the last stage applies ``_loss_fn`` and the
                accumulated loss is broadcast to all stages.

        Returns:
            The broadcast loss if ``compute_loss``, otherwise the list of this
            stage's per-micro-batch outputs. Also stored on ``self.train_loss``.
        """
        self._layers.eval()
        self._compute_loss = compute_loss
        self._reset_batch_state(data)

        output_buffers = []

        # Forward-only: each micro-batch is received, computed and sent on.
        # This is the same p2p call sequence as a warm-up/steady split.
        for _ in range(self.accumulate_steps):
            input_tensor = p2p.recv_forward()
            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor)
            output_buffers.append(output_tensor)

        if self._compute_loss:
            self.train_loss = self._broadcast_final_loss()
        else:
            self.train_loss = output_buffers

        return self.train_loss

    def _forward_step(self, input_tensor):
        """Run the local layers on one micro-batch and advance the counter.

        The first stage ignores ``input_tensor`` and loads its own inputs.
        When ``_compute_loss`` is set, the last stage turns its output into a
        loss. That loss is divided by ``accumulate_steps`` when larger than 1,
        and its detached value is added to ``self.total_loss``.
        """
        if self.is_first_stage:
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        output_tensor = self._layers.forward(input_tensor)

        if self.is_last_stage and self._compute_loss:
            assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
            # Read labels (index 1) explicitly. With a single stage,
            # _load_micro_batch takes the first-stage branch and would
            # return the inputs instead.
            labels = self._slice_micro_batch(self.micro_batch_id, 1)
            output_tensor = self._layers._loss_fn(output_tensor, labels)
            assert isinstance(
                output_tensor, paddle.Tensor
            ), "Currently, loss_fn should obtain Paddle.Tensor dtype"

            with paddle.amp.auto_cast(enable=False):
                if self.accumulate_steps > 1:
                    output_tensor = output_tensor / self.accumulate_steps

                if self.total_loss is None:
                    self.total_loss = paddle.zeros_like(output_tensor)
                self.total_loss += output_tensor.detach()

        self.micro_batch_id += 1
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        """Back-propagate one micro-batch and return the input gradient(s).

        The last stage back-propagates its loss (scaled when ``self.scaler``
        is set). Other stages back-propagate ``output_tensor_grad``. For tuple
        outputs, the gradients pair up with the outputs whose
        ``stop_gradient`` is False.

        Returns:
            ``None`` when ``input_tensor`` is None. Otherwise the ``.grad`` of
            ``input_tensor``, or, for tuples, a tuple of ``.grad`` of the
            elements with ``stop_gradient`` False.
        """
        if self.is_last_stage:
            assert output_tensor_grad is None
            if self.scaler:
                paddle.autograd.backward(self.scaler.scale(output_tensor))
            else:
                paddle.autograd.backward(output_tensor)
        elif isinstance(output_tensor, tuple):
            outputs = [t for t in output_tensor if not t.stop_gradient]
            assert len(outputs) == len(output_tensor_grad)
            paddle.autograd.backward(
                tensors=outputs, grad_tensors=list(output_tensor_grad))
        else:
            paddle.autograd.backward(
                tensors=[output_tensor], grad_tensors=[output_tensor_grad])

        if input_tensor is None:
            return None
        if isinstance(input_tensor, tuple):
            return tuple(t.grad for t in input_tensor if not t.stop_gradient)
        return input_tensor.grad

    def _load_micro_batch(self, cache_id):
        """Return this stage's slice of micro-batch ``cache_id``.

        The first stage gets ``self.data[0]``, the last stage gets
        ``self.data[1]``, and other stages get ``None``.
        """
        if self.is_first_stage:
            return self._slice_micro_batch(cache_id, 0)
        if self.is_last_stage:
            return self._slice_micro_batch(cache_id, 1)
        # No data input is required for other stages
        return None

    def _slice_micro_batch(self, cache_id, index):
        """Slice micro-batch ``cache_id`` out of ``self.data[index]``.

        ``self.data[index]`` is a tensor or a tuple of tensors, sliced along
        dim 0. Results are detached.
        """
        inputs = self.data
        assert len(inputs) == 2, "length of input should be 2"
        part = inputs[index]
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if isinstance(part, tuple):
            if index == 0:
                assert len(
                    part
                ) > 1, "If you use tuple for input data, it should have at least two inputs."
            self._check_batch_size(part[0].shape[0])
            return tuple(tensor[begin:end].detach() for tensor in part)

        self._check_batch_size(part.shape[0])
        return part[begin:end].detach()

    def _check_batch_size(self, batch_size):
        """Assert the global batch splits exactly into the micro-batches."""
        assert self.micro_batch_size * self.accumulate_steps == batch_size, (
            "batch_size needs to be equal to micro_batch_size * "
            "accumulate_steps. Currently, batch_size = %d, "
            "micro_batch_size = %d, accumulate_steps = %d." %
            (batch_size, self.micro_batch_size, self.accumulate_steps))

    def _broadcast_final_loss(self):
        """Share the last stage's accumulated loss with every pipeline stage.

        The last stage first broadcasts a flag (1 if its loss is float32,
        else 0) so other stages can allocate a receive buffer of the matching
        dtype. It then broadcasts the loss itself within ``self.pp_group``.

        Returns:
            A detached ``self.total_loss`` on the last stage, the received
            loss elsewhere.
        """
        if self.is_last_stage:
            assert self.total_loss is not None, "train_batch() in last stage should obtain vaild loss"
            loss = self.total_loss.detach()
            src_rank = self.global_rank
            is_fp32 = paddle.to_tensor(
                1) if loss.dtype == paddle.float32 else paddle.to_tensor(0)
            paddle.distributed.broadcast(
                is_fp32,
                src=src_rank,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            src_rank = self._hcg.get_rank_from_stage(self.num_stages - 1)
            is_fp32 = paddle.to_tensor(1)
            paddle.distributed.broadcast(
                is_fp32,
                src=src_rank,
                use_calc_stream=True,
                group=self.pp_group)
            # NOTE(review): any non-float32 loss is received into a float16
            # buffer of shape [1]; other dtypes/shapes are not handled here.
            loss_dtype = "float32" if is_fp32.numpy()[0] else "float16"
            loss = paddle.zeros(shape=[1], dtype=loss_dtype)

        paddle.distributed.broadcast(
            loss, src=src_rank, use_calc_stream=True, group=self.pp_group)
        return loss

    def _optimizer_step(self):
        """Step the optimizer (through the scaler if set), clear grads, step LR."""
        if self.scaler:
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()