#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.fluid as fluid
from .meta_parallel_base import MetaParallelBase
from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
from .parallel_layers.pp_layers import PipelineLayer

from ..utils.hybrid_parallel_util import broadcast_mp_parameters
from ..utils.hybrid_parallel_util import broadcast_dp_parameters
from ..utils.log_util import logger
from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
from .pp_utils import p2p_communication as p2p

__all__ = []


class PipelineParallel(MetaParallelBase):
    def __init__(self, layers, hcg, strategy):
        if not isinstance(layers, PipelineLayer):
            raise TypeError(
                "The Layer should be a derived class of PipelineLayer.")
        super(PipelineParallel, self).__init__(layers, hcg, strategy)
        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1

        self.total_loss = None

        self.micro_batch_size = self._strategy.pipeline_configs[
            'micro_batch_size']
        self.accumulate_steps = self._strategy.pipeline_configs[
            'accumulate_steps']

        self.num_stages = self._hcg.get_pipe_parallel_world_size()
        self.stage_id = self._hcg.get_stage_id()
        self.pp_group = self._hcg.get_pipe_parallel_group()

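        # Set up the point-to-point groups used to exchange activations and
        # gradients between adjacent pipeline stages.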
        p2p.initialize_p2p_groups(hcg)

        _initialize_recompute_hcg(hcg)

        self.is_first_stage = self.stage_id == 0
        self.is_last_stage = (self.stage_id == (self.num_stages - 1))
        self.global_rank = self._hcg.get_global_rank()
        self.micro_batch_id = 0

        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
            self.num_stages, self.stage_id))

        if self.use_model_parallel:
            logger.info("start broadcast mp parameters")
            broadcast_mp_parameters(self._layers, self._hcg)

        if self.use_data_parallel:
            logger.info("start broadcast dp parameters")
            broadcast_dp_parameters(self._layers, self._hcg)

    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
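        """Run the 1F1B pipeline schedule over one global batch, then apply a
        single optimizer step.

        `data` is consumed only by the first stage (inputs) and the last
        stage (labels); it is ignored on intermediate stages.
        """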
        assert isinstance(optimizer, HybridParallelOptimizer), (
            'optimizer should be HybridParallelOptimizer subclass.')
        if scaler is not None:
            assert isinstance(scaler, HybridParallelGradScaler), (
                'scaler should be HybridParallelGradScaler subclass or None.')
        assert fluid.framework._dygraph_tracer()._has_grad, (
            'Please enable the generation of gradients.')

        if self.is_first_stage or self.is_last_stage:
            assert data is not None, (
                "For the first and the last stage, the data must be set.")
        else:
            data = None

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.scaler = scaler
        self.data = data

        self._layers.train()

        # store total loss of entire batch
        self.total_loss = None

        # reset the micro batch index used by _load_micro_batch
        self.micro_batch_id = 0

        # Next, use the 1F1B scheduling strategy.
        # This strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py

        startup_steps = (self.num_stages - self.stage_id - 1)
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps
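        # Each stage runs (num_stages - stage_id - 1) warmup forwards, so the
        # last stage enters the steady 1F1B phase immediately. For example,
        # with num_stages = 4 and accumulate_steps = 8, stage 0 performs 3
        # warmup forwards and 5 steady steps.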

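        # Activations of in-flight micro batches are buffered here until
        # their corresponding backward step runs.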
        input_buffers = []
        output_buffers = []

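        # Warmup phase: forward-only. Receive the activation from the
        # previous stage, run one forward step, and send the result onward.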
        for step_id in range(startup_steps):
            input_tensor = p2p.recv_forward()

            output_tensor = self._forward_step(input_tensor)
            p2p.send_forward(output_tensor)

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
            input_tensor = p2p.recv_forward()

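        # Steady 1F1B phase: each iteration runs one forward for a new micro
        # batch and one backward for the oldest buffered micro batch, which
        # keeps every stage busy while bounding activation memory.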
        for i in range(steady_steps):
            last_iter = (i == (steady_steps - 1))

            output_tensor = self._forward_step(input_tensor)

            output_tensor_grad = p2p.send_forward_recv_backward(output_tensor)

            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            input_tensor, output_tensor = input_buffers.pop(
                0), output_buffers.pop(0)

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)

            if last_iter:
                input_tensor = None
                p2p.send_backward(input_tensor_grad)
            else:
                input_tensor = p2p.send_backward_recv_forward(input_tensor_grad)

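        # Cooldown phase: drain the pending backward steps of the micro
        # batches whose forwards ran during warmup.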
        for i in range(startup_steps):
            input_tensor = input_buffers.pop(0)
            output_tensor = output_buffers.pop(0)

            output_tensor_grad = p2p.recv_backward()

            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
                                                    output_tensor_grad)
            p2p.send_backward(input_tensor_grad)

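        # Sync gradients of weights shared across stages (typically tied
        # embedding weights) before the optimizer step.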
        self._layers.allreduce_shared_weight_gradients()

        self.train_loss = self._reduce_final_loss()

        # apply the optimizer update for the whole accumulated batch
        self._optimizer_step()
        return self.train_loss

    def _forward_step(self, input_tensor):
        if self.is_first_stage:
            input_tensor = self._load_micro_batch(self.micro_batch_id)

        output_tensor = self._layers.forward(input_tensor)

        if self.is_last_stage:
            labels = self._load_micro_batch(self.micro_batch_id)
            output_tensor = self._layers._loss_fn(output_tensor, labels)
            assert isinstance(output_tensor, paddle.Tensor), (
                "Currently, loss_fn should return a paddle.Tensor.")

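            # Average over accumulate_steps so that accumulating the
            # per-micro-batch losses yields the mean loss of the full batch.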
            if self.accumulate_steps > 1:
                output_tensor = output_tensor / self.accumulate_steps

            if self.total_loss is None:
                self.total_loss = paddle.zeros_like(output_tensor)
            self.total_loss += output_tensor.detach()

        self.micro_batch_id += 1
        return output_tensor

    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
        if self.is_last_stage:
            assert output_tensor_grad is None
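            # The last stage starts backward from the loss itself; with AMP,
            # scale the loss first so gradients are computed in the scaled
            # range.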
            if self.scaler:
                paddle.autograd.backward(self.scaler.scale(output_tensor))
            else:
                paddle.autograd.backward(output_tensor)
        else:
            if isinstance(output_tensor, tuple):
                outputs = [t for t in output_tensor if not t.stop_gradient]
                assert len(outputs) == len(output_tensor_grad)
                paddle.autograd.backward(
                    tensors=outputs,
                    grad_tensors=[t for t in output_tensor_grad])
            else:
                paddle.autograd.backward(
                    tensors=[output_tensor], grad_tensors=[output_tensor_grad])

        input_tensor_grad = None
        if input_tensor is not None:
            if isinstance(input_tensor, tuple):
                input_tensor_grad = tuple(
                    [t.grad for t in input_tensor if not t.stop_gradient])
            else:
                input_tensor_grad = input_tensor.grad
        return input_tensor_grad

    def _load_micro_batch(self, cache_id):
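        # Slice the cache_id-th micro batch out of the full batch along the
        # batch dimension.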
        inputs = self.data
        begin = cache_id * self.micro_batch_size
        end = begin + self.micro_batch_size

        if self.is_first_stage:
            assert len(inputs) == 2, "Data should be a tuple of (inputs, labels)."
            if isinstance(inputs[0], tuple):
                assert len(
                    inputs[0]
                ) > 1, "If you use tuple for input data, it should have at least two inputs."
                batch_size = inputs[0][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
                    "batch_size must equal micro_batch_size * accumulate_steps. "
                    "Currently, batch_size = %d, micro_batch_size = %d, "
                    "accumulate_steps = %d." %
                    (batch_size, self.micro_batch_size, self.accumulate_steps))
                data = [d[begin:end, :].detach() for d in inputs[0]]
                return tuple(data)
            else:
                batch_size = inputs[0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return inputs[0][begin:end, :].detach()
        elif self.is_last_stage:
            assert len(inputs) == 2, "Data should be a tuple of (inputs, labels)."
            if isinstance(inputs[1], tuple):
                batch_size = inputs[1][0].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                data = [d[begin:end, :].detach() for d in inputs[1]]
                return tuple(data)
            else:
                batch_size = inputs[1].shape[0]
                assert self.micro_batch_size * self.accumulate_steps == batch_size
                return inputs[1][begin:end, :].detach()
        else:
            # No data input is required for the intermediate stages.
            return None

    def _reduce_final_loss(self):
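        # The loss lives only on the last stage; broadcast it within the
        # pipeline group so every rank returns the same value.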
        if self.is_last_stage:
            assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
            loss = self.total_loss.detach()
            paddle.distributed.broadcast(
                loss,
                src=self.global_rank,
                use_calc_stream=True,
                group=self.pp_group)
        else:
            loss = paddle.zeros(shape=[1], dtype="float32")
            paddle.distributed.broadcast(
                loss,
                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
                use_calc_stream=True,
                group=self.pp_group)
        return loss

    def _optimizer_step(self):
        if self.scaler:
            self.scaler.minimize(self.optimizer, self.train_loss)
        else:
            self.optimizer.step()

        self.optimizer.clear_grad()
        if self.lr_scheduler:
            self.lr_scheduler.step()
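

# A minimal usage sketch (hypothetical names such as `train_loader`; assumes
# the distributed environment and the hybrid-parallel hcg/strategy have
# already been initialized):
#
#     model = PipelineParallel(layers, hcg, strategy)  # layers: PipelineLayer
#     for inputs, labels in train_loader:
#         loss = model.train_batch((inputs, labels), optimizer, lr_scheduler)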